From 35cd2bff6a4b3d8ed4753389cd5f838808569106 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 16:33:53 -0500 Subject: [PATCH 1/5] save migrate --- enzyme/BUILD | 143 ++++ enzyme/Enzyme/MLIR/Dialect/CMakeLists.txt | 1 + enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td | 90 --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 597 +--------------- .../MLIR/Dialect/Impulse/CMakeLists.txt | 30 + enzyme/Enzyme/MLIR/Dialect/Impulse/Dialect.td | 32 + enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h | 31 + .../MLIR/Dialect/Impulse/ImpulseDialect.cpp | 25 + .../MLIR/Dialect/Impulse/ImpulseEnums.td | 110 +++ .../MLIR/Dialect/Impulse/ImpulseOps.cpp | 143 ++++ .../Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td | 655 ++++++++++++++++++ enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 139 ---- .../CoreDialectsAutoDiffImplementations.cpp | 1 + .../CoreDialectsAutoDiffImplementations.h | 1 + .../MLIR/Implementations/EnzymeDerivatives.td | 1 - .../ImpulseAutoDiffOpInterfaceImpl.cpp | 28 + .../Implementations/ImpulseDerivatives.td | 3 + .../Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp | 33 +- enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt | 3 + enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp | 309 +++++---- enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h | 15 +- .../Enzyme/MLIR/Interfaces/TransformUtils.cpp | 88 +-- .../Enzyme/MLIR/Interfaces/TransformUtils.h | 11 +- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 2 + enzyme/Enzyme/MLIR/Passes/Passes.td | 1 + .../Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp | 163 ++--- enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 2 + enzyme/test/MLIR/ProbProg/exp_transform.mlir | 114 +-- enzyme/test/MLIR/ProbProg/generate.mlir | 58 +- enzyme/test/MLIR/ProbProg/hmc_diag_mass.mlir | 66 +- enzyme/test/MLIR/ProbProg/hmc_kernel.mlir | 82 +-- .../MLIR/ProbProg/mcmc_custom_logpdf.mlir | 54 +- enzyme/test/MLIR/ProbProg/mcmc_sampling.mlir | 56 +- .../test/MLIR/ProbProg/mcmc_strong_zero.mlir | 20 +- enzyme/test/MLIR/ProbProg/mcmc_warmup.mlir | 232 +++---- enzyme/test/MLIR/ProbProg/mh.mlir | 44 +- 
enzyme/test/MLIR/ProbProg/nuts_kernel.mlir | 84 +-- enzyme/test/MLIR/ProbProg/regenerate.mlir | 58 +- enzyme/test/MLIR/ProbProg/roundtrip.mlir | 12 +- enzyme/test/MLIR/ProbProg/simulate.mlir | 42 +- enzyme/test/MLIR/ProbProg/untraced_call.mlir | 6 +- 41 files changed, 1999 insertions(+), 1586 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt create mode 100644 enzyme/Enzyme/MLIR/Dialect/Impulse/Dialect.td create mode 100644 enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h create mode 100644 enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp create mode 100644 enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td create mode 100644 enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp create mode 100644 enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td diff --git a/enzyme/BUILD b/enzyme/BUILD index 8551149ca33c..2384ff2ef2ee 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -441,6 +441,24 @@ td_library( ], ) +td_library( + name = "ImpulseDialectTdFiles", + srcs = [ + "Enzyme/MLIR/Dialect/Impulse/Dialect.td", + "Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td", + ], + includes = [".", "Enzyme/MLIR/Dialect"], + deps = [ + ":EnzymeDialectTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@llvm-project//mlir:ViewLikeInterfaceTdFiles", + ], +) + cc_library( name = "LLVMExtDialect", srcs = [ @@ -621,6 +639,111 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) +gentbl_cc_library( + name = "ImpulseOpsIncGen", + strip_include_prefix = "Enzyme/MLIR", + tbl_outs = [ + ( + ["-gen-op-decls"], + "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.h.inc", + ), + ( + ["-gen-op-defs"], + 
"Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp.inc", + ), + ( + [ + "-gen-dialect-decls", + "-dialect=impulse", + ], + "Enzyme/MLIR/Dialect/Impulse/ImpulseOpsDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=impulse", + ], + "Enzyme/MLIR/Dialect/Impulse/ImpulseOpsDialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td", + deps = [":ImpulseDialectTdFiles"], +) + +gentbl_cc_library( + name = "ImpulseTypesIncGen", + strip_include_prefix = "Enzyme/MLIR", + tbl_outs = [ + ( + ["-gen-typedef-decls"], + "Enzyme/MLIR/Dialect/Impulse/ImpulseOpsTypes.h.inc", + ), + ( + ["-gen-typedef-defs"], + "Enzyme/MLIR/Dialect/Impulse/ImpulseOpsTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td", + deps = [":ImpulseDialectTdFiles"], +) + +gentbl_cc_library( + name = "ImpulseEnumsIncGen", + strip_include_prefix = "Enzyme/MLIR", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td", + deps = [":ImpulseDialectTdFiles"], +) + +gentbl_cc_library( + name = "ImpulseAttributesIncGen", + strip_include_prefix = "Enzyme/MLIR", + tbl_outs = [ + ( + [ + "-gen-attrdef-decls", + "-attrdefs-dialect=impulse", + ], + "Enzyme/MLIR/Dialect/Impulse/ImpulseAttributes.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + "-attrdefs-dialect=impulse", + ], + "Enzyme/MLIR/Dialect/Impulse/ImpulseAttributes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td", + deps = [":ImpulseDialectTdFiles"], +) + +gentbl_cc_library( + name = "ImpulseAttributeInterfacesIncGen", + tbl_outs = [ + ( + ["-gen-attr-interface-decls"], + 
"Enzyme/MLIR/Dialect/Impulse/ImpulseAttributeInterfaces.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td", + deps = [":ImpulseDialectTdFiles"], +) + gentbl_cc_library( name = "EnzymeTypeInterfacesIncGen", strip_include_prefix = "Enzyme", @@ -820,10 +943,23 @@ gentbl_cc_library( deps = [":ImplementationsCommonTdFiles"], ) +gentbl_cc_library( + name = "impulse-derivatives", + strip_include_prefix = "Enzyme/MLIR", + tbl_outs = [( + ["-gen-mlir-derivatives"], + "Enzyme/MLIR/Implementations/ImpulseDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/ImpulseDerivatives.td", + deps = [":ImplementationsCommonTdFiles"], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ "Enzyme/MLIR/Dialect/*.cpp", + "Enzyme/MLIR/Dialect/Impulse/*.cpp", "Enzyme/MLIR/Passes/*.cpp", "Enzyme/MLIR/Interfaces/*.cpp", "Enzyme/MLIR/Analysis/*.cpp", @@ -831,6 +967,7 @@ cc_library( ]), hdrs = glob([ "Enzyme/MLIR/Dialect/*.h", + "Enzyme/MLIR/Dialect/Impulse/*.h", "Enzyme/MLIR/Passes/*.h", "Enzyme/MLIR/Interfaces/*.h", "Enzyme/MLIR/Analysis/*.h", @@ -852,12 +989,18 @@ cc_library( ":EnzymePassesIncGen", ":EnzymeTypeInterfacesIncGen", ":EnzymeTypesIncGen", + ":ImpulseAttributeInterfacesIncGen", + ":ImpulseAttributesIncGen", + ":ImpulseEnumsIncGen", + ":ImpulseOpsIncGen", + ":ImpulseTypesIncGen", ":LLVMExtDialect", ":affine-derivatives", ":arith-derivatives", ":cf-derivatives", ":complex-derivatives", ":enzyme-derivatives", + ":impulse-derivatives", ":func-derivatives", ":linalg-derivatives", ":llvm-derivatives", diff --git a/enzyme/Enzyme/MLIR/Dialect/CMakeLists.txt b/enzyme/Enzyme/MLIR/Dialect/CMakeLists.txt index 6eff7520f63a..2adf717669d9 100644 --- a/enzyme/Enzyme/MLIR/Dialect/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Dialect/CMakeLists.txt @@ -30,4 +30,5 @@ MLIRIR MLIRMemRefDialect ) +add_subdirectory(Impulse) add_subdirectory(LLVMExt) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td 
b/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td index fbfaa781d57b..34161c10f3f7 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td @@ -33,94 +33,4 @@ def Activity : I32EnumAttr<"Activity", def ActivityAttr : EnumAttr; -def RngDistribution : I32EnumAttr<"RngDistribution", - "Random number distribution type", - [ - I32EnumAttrCase<"UNIFORM", 0>, - I32EnumAttrCase<"NORMAL", 1>, - I32EnumAttrCase<"MULTINORMAL", 2>, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::enzyme"; -} - -def RngDistributionAttr : EnumAttr; - -def Transpose : I32EnumAttr<"Transpose", - "Transpose mode for matrix operations", - [ - I32EnumAttrCase<"NO_TRANSPOSE", 0>, - I32EnumAttrCase<"TRANSPOSE", 1>, - I32EnumAttrCase<"ADJOINT", 2>, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::enzyme"; -} - -def TransposeAttr : EnumAttr; - -def SupportKind : I32EnumAttr<"SupportKind", - "Domain where distribution density is non-zero", - [ - I32EnumAttrCase<"REAL", 0>, - I32EnumAttrCase<"POSITIVE", 1>, - I32EnumAttrCase<"UNIT_INTERVAL", 2>, - I32EnumAttrCase<"INTERVAL", 3>, - I32EnumAttrCase<"GREATER_THAN", 4>, - I32EnumAttrCase<"LESS_THAN", 5>, - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::enzyme"; -} - -def SupportKindAttr : EnumAttr; - -def SupportAttr : Enzyme_Attr<"Support", "support"> { - let summary = "Distribution support specification for constraint transforms"; - let description = [{ - Distribution support specification for constraint transforms. - }]; - let parameters = (ins - "SupportKind":$kind, - OptionalParameter<"FloatAttr">:$lower_bound, - OptionalParameter<"FloatAttr">:$upper_bound - ); - let assemblyFormat = "`<` $kind (`,` `lower` `=` $lower_bound^)? (`,` `upper` `=` $upper_bound^)? 
`>`"; -} - -def HMCConfigAttr : Enzyme_Attr<"HMCConfig", "hmc_config"> { - let summary = "Configuration for HMC inference"; - let description = [{ - HMC-specific parameters for fixed-trajectory Hamiltonian Monte Carlo. - - trajectory_length: Total trajectory length for leapfrog integration. - - adapt_step_size (default: true): Whether to adapt step size during warmup using dual averaging. - - adapt_mass_matrix (default: true): Whether to adapt the mass matrix during warmup using Welford covariance. - }]; - let parameters = (ins - "FloatAttr":$trajectory_length, - DefaultValuedParameter<"bool", "true">:$adapt_step_size, - DefaultValuedParameter<"bool", "true">:$adapt_mass_matrix - ); - let assemblyFormat = "`<` `trajectory_length` `=` $trajectory_length (`,` `adapt_step_size` `=` $adapt_step_size^)? (`,` `adapt_mass_matrix` `=` $adapt_mass_matrix^)? `>`"; -} - -def NUTSConfigAttr : Enzyme_Attr<"NUTSConfig", "nuts_config"> { - let summary = "Configuration for NUTS inference"; - let description = [{ - NUTS-specific parameters for No-U-Turn Sampler with adaptive path length. - - max_tree_depth: Maximum tree depth, controls max leapfrog steps (2^max_tree_depth). - - max_delta_energy: Threshold for trajectory divergence (rejected if energy change exceeds this). - Default: 1000.0 - - adapt_step_size (default: true): Whether to adapt step size during warmup using dual averaging. - - adapt_mass_matrix (default: true): Whether to adapt the mass matrix during warmup using Welford covariance. - }]; - let parameters = (ins - DefaultValuedParameter<"int64_t", "10">:$max_tree_depth, - OptionalParameter<"FloatAttr">:$max_delta_energy, - DefaultValuedParameter<"bool", "true">:$adapt_step_size, - DefaultValuedParameter<"bool", "true">:$adapt_mass_matrix - ); - let assemblyFormat = "`<` (`max_tree_depth` `=` $max_tree_depth^)? (`,` `max_delta_energy` `=` $max_delta_energy^)? (`,` `adapt_step_size` `=` $adapt_step_size^)? (`,` `adapt_mass_matrix` `=` $adapt_mass_matrix^)? 
`>`"; -} - #endif // ENZYME_ENUMS diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index d3af7a0bb6b8..f28396026faf 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -23,7 +23,6 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" -include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/Arith/IR/ArithBase.td" @@ -228,41 +227,14 @@ def ForwardDiffRegionOp : Enzyme_Op<"fwddiff_region", [AutomaticAllocationScope] } def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["AutoDiffRegionOp", "ForwardDiffRegionOp", "ForLoopOp", "WhileLoopOp", "IfOp"]>]> { - let summary = "Yield values at the end of an autodiff_region, loop, or if ops"; + ParentOneOf<["AutoDiffRegionOp", "ForwardDiffRegionOp"]>]> { + let summary = "Yield values at the end of an autodiff_region"; let arguments = (ins Variadic:$operands); let assemblyFormat = [{ attr-dict ($operands^ `:` type($operands))? }]; } -def ForLoopOp : Enzyme_Op<"for_loop", [AutomaticAllocationScope]> { - let summary = "Counted loop for probabilistic programming"; - let description = [{ - A counted loop operation that iterates from `lowerBound` to `upperBound` - by `step`, carrying `iter_args` through each iteration. The iteration - variable and iter_args are passed to the body region. - }]; - - let arguments = (ins - AnyType:$lowerBound, - AnyType:$upperBound, - AnyType:$step, - Variadic:$initArgs - ); - - let regions = (region SizedRegion<1>:$region); - let results = (outs Variadic:$results); - - let assemblyFormat = [{ - `(` $lowerBound `:` type($lowerBound) `)` - `to` `(` $upperBound `:` type($upperBound) `)` - `step` `(` $step `:` type($step) `)` - (`iter_args` `(` $initArgs^ `:` type($initArgs) `)`)? 
- `->` type(results) $region attr-dict - }]; -} - def BatchOp : Enzyme_Op<"batch", [DeclareOpInterfaceMethods]> { let summary = "Perform reverse mode AD on a funcop"; @@ -443,490 +415,6 @@ def ExtractOp : Enzyme_Op<"extract"> { }]; } -// Probabilistic programming -def SymbolAttr : Enzyme_Attr<"Symbol", "symbol"> { - let summary = "Symbol associated with a Sample op"; - let description = [{ - Symbol associated with a Sample op. - }]; - let parameters = (ins "uint64_t":$ptr); - let assemblyFormat = "`<` $ptr `>`"; -} - -def AddressAttr : TypedArrayAttrBase; -def AddressArrayAttr : TypedArrayAttrBase; - -def SampleOp : Enzyme_Op<"sample", - [DeclareOpInterfaceMethods]> { - let summary = "Sample from a distribution"; - let description = [{ - Sample from a distribution. By convention, the 0th operand in `inputs` - or `outputs` is the initial RNG state (seed). - }]; - - let arguments = (ins - FlatSymbolRefAttr:$fn, - Variadic:$inputs, - OptionalAttr:$logpdf, - OptionalAttr:$symbol, - OptionalAttr:$support, - DefaultValuedStrAttr:$name - ); - - let results = (outs Variadic:$outputs); - - let assemblyFormat = [{ - $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) - }]; -} - -def UntracedCallOp : Enzyme_Op<"untracedCall"> { - let summary = "Call a probabilistic function without tracing"; - let description = [{ - Call a probabilistic function without tracing. By convention, the 0th operand in `inputs` - or `outputs` is the initial RNG state (seed). 
- }]; - - let arguments = (ins - FlatSymbolRefAttr:$fn, - Variadic:$inputs, - DefaultValuedStrAttr:$name - ); - - let results = (outs Variadic:$outputs); - - let assemblyFormat = [{ - $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) - }]; -} - -def SimulateOp : Enzyme_Op<"simulate", [DeclareOpInterfaceMethods]> { - let summary = "Simulate a generative function"; - let description = [{ - Simulates a generative function, building a trace tensor containing all - sampled values and computing the accumulated log probability weight. - - The `selection` attribute specifies all sample addresses in order, - determining the trace tensor layout. - - Returns: (trace, weight, rng, retvals...) - - trace: tensor<1 x position_size x f64> - flattened samples - - weight: tensor - accumulated log probability - - rng: updated RNG state - - retvals: original function return values - }]; - - let arguments = (ins - FlatSymbolRefAttr:$fn, - Variadic:$inputs, - AddressArrayAttr:$selection, - DefaultValuedStrAttr:$name - ); - - let results = (outs - AnyRankedTensor:$trace, - AnyRankedTensor:$weight, - Variadic:$outputs - ); - - let assemblyFormat = [{ - $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) - }]; -} - -def GenerateOp : Enzyme_Op<"generate", [DeclareOpInterfaceMethods]> { - let summary = "Constrained generation from a generative function"; - let description = [{ - Generates from a generative function with some addresses constrained. - The constraint tensor contains flattened constrained values in the order - specified by constrained_addresses. - - Returns: (trace, weight, rng, retvals...) 
- }]; - - let arguments = (ins - FlatSymbolRefAttr:$fn, - Variadic:$inputs, - AnyRankedTensor:$constraint, - AddressArrayAttr:$selection, - AddressArrayAttr:$constrained_addresses, - DefaultValuedStrAttr:$name - ); - - let results = (outs - AnyRankedTensor:$trace, - AnyRankedTensor:$weight, - Variadic:$outputs - ); - - let assemblyFormat = [{ - $fn `(` $inputs `)` `given` $constraint attr-dict `:` functional-type(operands, results) - }]; -} - -def RandomOp : Enzyme_Op<"random"> { - let summary = "Generate random numbers using specified distribution"; - let description = [{ - Generates random numbers using the rng_distribution algorithm and produces - a result tensor. - - If rng_distribution = UNIFORM, then the random numbers are generated following - the uniform distribution over the interval [a, b). If a >= b, the behavior is - undefined. - - If rng_distribution = NORMAL, then the random numbers are generated following - the normal distribution with mean = a and standard deviation = b. If b < 0, - the behavior is undefined. - - If rng_distribution = MULTINORMAL, then the random numbers are generated - following the multivariate normal distribution with mean = a (scalar or vector) - and covariance matrix = b. The parameter b should be a positive definite matrix. - - By convention, the 0th operand in inputs is the initial RNG state and the - 0th operand in results is the updated RNG state. - }]; - - let arguments = (ins - AnyType:$rng_state, - AnyType:$a, - AnyType:$b, - RngDistributionAttr:$rng_distribution - ); - - let results = (outs AnyType:$output_rng_state, AnyType:$result); - - let assemblyFormat = [{ - $rng_state `,` $a `,` $b attr-dict `:` functional-type(operands, results) - }]; -} - -def RandomSplitOp : Enzyme_Op<"randomSplit"> { - let summary = "Split RNG state into multiple independent states"; - let description = [{ - Splits an RNG state into multiple independent RNG states. 
- Reference: https://github.com/jax-ml/jax/blob/c25e095fcec9678a4ce5f723afce0c6a3c48a5e7/jax/_src/random.py#L281-L294 - }]; - - let arguments = (ins AnyType:$rng_state); - let results = (outs Variadic:$output_rng_states); - - let assemblyFormat = [{ - $rng_state attr-dict `:` functional-type(operands, results) - }]; -} - -def RegenerateOp : Enzyme_Op<"regenerate", [DeclareOpInterfaceMethods]> { - let summary = "Regenerate selected addresses in a trace"; - let description = [{ - Regenerates selected addresses while keeping others fixed. - Used internally by MH. - - Takes explicit old_trace and returns new trace with weight. - - Returns: (new_trace, weight, retvals...) - - new_trace: tensor<1 x position_size x f64> - flattened samples - - weight: tensor - accumulated log probability - - retvals: original function return values - }]; - - let arguments = (ins - FlatSymbolRefAttr:$fn, - Variadic:$inputs, - AnyRankedTensor:$original_trace, - AddressArrayAttr:$selection, - AddressArrayAttr:$regenerate_addresses, - DefaultValuedStrAttr:$name - ); - - let results = (outs - AnyRankedTensor:$new_trace, - AnyRankedTensor:$weight, - Variadic:$outputs - ); - - let assemblyFormat = [{ - $fn `(` $inputs `)` `given` $original_trace attr-dict `:` functional-type(operands, results) - }]; -} - -def MHOp : Enzyme_Op<"mh", [DeclareOpInterfaceMethods]> { - let summary = "Metropolis-Hastings step for probabilistic inference"; - let description = [{ - Performs one MH step: regenerates selected addresses and accepts/rejects - based on weight ratio. 
- }]; - - let arguments = (ins - FlatSymbolRefAttr:$fn, - AnyRankedTensor:$original_trace, - AnyRankedTensor:$original_weight, - Variadic:$inputs, - AddressArrayAttr:$selection, - AddressArrayAttr:$regenerate_addresses, - DefaultValuedStrAttr:$name - ); - - let results = (outs - AnyRankedTensor:$new_trace, - AnyRankedTensor:$new_weight, - AnyRankedTensor:$accepted, - AnyType:$output_rng - ); - - let assemblyFormat = [{ - $fn `(` $inputs `)` `given` $original_trace `weight` $original_weight attr-dict `:` functional-type(operands, results) - }]; -} - -def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods, AttrSizedOperandSegments]> { - let summary = "MCMC inference for probabilistic programs"; - let description = [{ - Runs MCMC inference on selected addresses. - - Two modes of operation: - 1. Trace-based mode: `fn` and `original_trace` are provided. The model - function with `enzyme.sample` ops defines the density. - 2. Custom logpdf mode: `logpdf_fn` and `initial_position` are provided. - The logpdf function maps position → scalar log-density directly. - - The `selection` attribute determines which addresses to sample via HMC/NUTS. - All sample addresses are included in the trace tensor for consistency. 
- - Returns: (trace, diagnostics, rng, final_position, final_gradient, - final_potential_energy, final_step_size, final_inverse_mass_matrix) - - trace: tensor - - diagnostics: tensor - placeholder for future expansion - - rng: updated RNG state - - final_position: tensor<1 x position_size x f64> - position after last sample - - final_gradient: tensor<1 x position_size x f64> - gradient at final position - - final_potential_energy: tensor - potential energy at final position - - final_step_size: tensor - adapted step size (after warmup) - - final_inverse_mass_matrix: tensor - adapted inverse mass matrix (after warmup) - }]; - - let arguments = (ins - OptionalAttr:$fn, - Variadic:$inputs, - Optional:$original_trace, - AddressArrayAttr:$selection, - AddressArrayAttr:$all_addresses, - - DefaultValuedAttr:$num_warmup, - DefaultValuedAttr:$num_samples, - DefaultValuedAttr:$thinning, - - // Algorithm-specific MCMC parameters - Optional:$inverse_mass_matrix, - Optional:$step_size, - - // Algorithm-specific configurations - OptionalAttr:$hmc_config, - OptionalAttr:$nuts_config, - - // Custom logpdf mode - OptionalAttr:$logpdf_fn, - Optional:$initial_position, - - Optional:$initial_gradient, - Optional:$initial_potential_energy, - - OptionalAttr:$autodiff_attrs, - DefaultValuedStrAttr:$name - ); - - let results = (outs - AnyRankedTensor:$trace, - AnyRankedTensor:$diagnostics, - AnyType:$output_rng_state, - AnyRankedTensor:$final_position, - AnyRankedTensor:$final_gradient, - AnyRankedTensor:$final_potential_energy, - AnyRankedTensor:$final_step_size, - AnyRankedTensor:$final_inverse_mass_matrix - ); - - let assemblyFormat = [{ - ($fn^)? - `(` $inputs `)` - (`given` $original_trace^)? - (`inverse_mass_matrix` `=` $inverse_mass_matrix^)? - (`step_size` `=` $step_size^)? - (`logpdf_fn` `=` $logpdf_fn^)? - (`initial_position` `=` $initial_position^)? - (`initial_gradient` `=` $initial_gradient^)? - (`initial_potential_energy` `=` $initial_potential_energy^)? 
- attr-dict `:` functional-type(operands, results) - }]; - - let hasVerifier = 1; -} - -def DotOp : Enzyme_Op<"dot", [Pure]> { - let summary = "Computes a general dot product operation"; - let description = [{ - Computes a general dot product operation. To be lowered to `stablehlo.dot_general`. - }]; - - let arguments = (ins - AnyRankedTensor:$lhs, - AnyRankedTensor:$rhs, - DenseI64ArrayAttr:$lhs_batching_dimensions, - DenseI64ArrayAttr:$rhs_batching_dimensions, - DenseI64ArrayAttr:$lhs_contracting_dimensions, - DenseI64ArrayAttr:$rhs_contracting_dimensions - ); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $lhs `,` $rhs attr-dict `:` functional-type(operands, results) - }]; -} - -def CholeskyOp : Enzyme_Op<"cholesky", [Pure]> { - let summary = "Compute Cholesky decomposition of a symmetric positive definite matrix"; - let description = [{ - Computes the Cholesky decomposition of a symmetric positive definite matrix A. - Returns L such that A = L @ L^T (if lower=true) or A = U^T @ U (if lower=false). - }]; - - let arguments = (ins - AnyRankedTensor:$input, - DefaultValuedAttr:$lower - ); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $input attr-dict `:` functional-type(operands, results) - }]; -} - -def TriangularSolveOp : Enzyme_Op<"triangular_solve", [Pure]> { - let summary = "Solve a triangular linear system"; - let description = [{ - Solves a system of linear equations with a triangular coefficient matrix. - If left_side=true, solves op(A) @ X = B for X. - If left_side=false, solves X @ op(A) = B for X. - op(A) is determined by transpose_a: NO_TRANSPOSE, TRANSPOSE, or ADJOINT. 
- }]; - - let arguments = (ins - AnyRankedTensor:$a, - AnyRankedTensor:$b, - DefaultValuedAttr:$left_side, - DefaultValuedAttr:$lower, - DefaultValuedAttr:$unit_diagonal, - DefaultValuedAttr:$transpose_a - ); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $a `,` $b attr-dict `:` functional-type(operands, results) - }]; -} - -def SelectOp : Enzyme_Op<"select", [Pure]> { - let summary = "Exteneded select operation"; - let description = [{ - Extended select operation that supports: - - `tensor` conditions with differently-sized operands - - standard cases supported by `arith.select` - }]; - - let arguments = (ins - AnyRankedTensor:$condition, - AnyType:$true_value, - AnyType:$false_value - ); - - let results = (outs AnyType:$result); - - let assemblyFormat = [{ - $condition `,` $true_value `,` $false_value attr-dict `:` functional-type(operands, results) - }]; -} - -def ReshapeOp : Enzyme_Op<"reshape", [Pure]> { - let summary = "Reshape a tensor to a new static shape"; - let arguments = (ins AnyRankedTensor:$input); - let results = (outs AnyRankedTensor:$result); - let assemblyFormat = [{ - $input attr-dict `:` functional-type($input, $result) - }]; -} - -def SliceOp : Enzyme_Op<"slice", [Pure]> { - let summary = "Extract a static slice from a tensor"; - let description = [{ - Extract a static slice from a tensor. - }]; - - let arguments = (ins - AnyRankedTensor:$operand, - DenseI64ArrayAttr:$start_indices, - DenseI64ArrayAttr:$limit_indices, - DenseI64ArrayAttr:$strides - ); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $operand attr-dict `:` functional-type($operand, $result) - }]; -} - -def DynamicSliceOp : Enzyme_Op<"dynamic_slice", [Pure]> { - let summary = "Extract a slice from a tensor at dynamic start indices"; - let description = [{ - Extract a slice from a tensor at dynamic start indices. 
- }]; - - let arguments = (ins - AnyRankedTensor:$operand, - Variadic:$start_indices, - DenseI64ArrayAttr:$slice_sizes - ); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $operand `,` $start_indices attr-dict `:` functional-type(operands, results) - }]; -} - -def DynamicUpdateSliceOp : Enzyme_Op<"dynamic_update_slice", [Pure]> { - let summary = "Update a slice in a tensor at dynamic start indices"; - let description = [{ - Update a slice in a tensor at dynamic start indices. - }]; - - let arguments = (ins - AnyRankedTensor:$operand, - AnyRankedTensor:$update, - Variadic:$start_indices - ); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $operand `,` $update `,` $start_indices attr-dict `:` functional-type(operands, results) - }]; -} - -def DumpOp : Enzyme_Op<"dump"> { - let summary = "Debug operation to dump a tensor value at runtime"; - let description = [{ - Debug operation that dumps a tensor value with a label. - }]; - - let arguments = (ins - AnyType:$value, - StrAttr:$label - ); - - let results = (outs AnyType:$output); - - let assemblyFormat = [{ - $value attr-dict `:` functional-type($value, results) - }]; -} - def AffineAtomicRMWOp : Enzyme_Op<"affine_atomic_rmw"> { let summary = "affine atomic rmw operation"; let description = [{ @@ -947,85 +435,4 @@ def AffineAtomicRMWOp : Enzyme_Op<"affine_atomic_rmw"> { }]; } -def WhileLoopOp : Enzyme_Op<"while_loop", [AutomaticAllocationScope]> { - let summary = "While loop with condition"; - let description = [{ - A while loop operation that continues iterating as long as the condition - evaluates to true. Intended to be lowered to `stablehlo.while`. 
- }]; - - let arguments = (ins Variadic:$initArgs); - let regions = (region SizedRegion<1>:$conditionRegion, - SizedRegion<1>:$bodyRegion); - let results = (outs Variadic:$results); - - let assemblyFormat = [{ - `(` $initArgs `:` type($initArgs) `)` - `->` type(results) - `condition` $conditionRegion - `body` $bodyRegion - attr-dict - }]; -} - -def IfOp : Enzyme_Op<"if", [AutomaticAllocationScope, RecursiveMemoryEffects]> { - let summary = "Conditional execution with two branches"; - let description = [{ - A conditional operation that executes exactly one of two branches based on a - boolean predicate. - }]; - - let arguments = (ins AnyType:$predicate); - let regions = (region SizedRegion<1>:$trueBranch, - SizedRegion<1>:$falseBranch); - let results = (outs Variadic:$results); - - let assemblyFormat = [{ - `(` $predicate `)` `(` $trueBranch `,` $falseBranch `)` - attr-dict `:` functional-type($predicate, results) - }]; -} - -def LogAddExpOp : Enzyme_Op<"log_add_exp", [Pure]> { - let summary = "Computes log(exp(x) + exp(y))"; - let description = [{ - Computes log(exp(x) + exp(y)). - }]; - - let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $lhs `,` $rhs attr-dict `:` functional-type(operands, results) - }]; -} - -def LogisticOp : Enzyme_Op<"logistic", [Pure]> { - let summary = "Computes logistic (sigmoid) function: 1 / (1 + exp(-x))"; - let description = [{ - Computes the logistic (sigmoid) function: 1 / (1 + exp(-x)). - }]; - - let arguments = (ins AnyRankedTensor:$operand); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $operand attr-dict `:` functional-type(operands, results) - }]; -} - -def PopcountOp : Enzyme_Op<"popcount", [Pure, SameOperandsAndResultElementType, SameOperandsAndResultShape]> { - let summary = "Computes population count"; - let description = [{ - Returns the number of 1-bits elementwise. 
- }]; - - let arguments = (ins AnyRankedTensor:$operand); - let results = (outs AnyRankedTensor:$result); - - let assemblyFormat = [{ - $operand attr-dict `:` functional-type(operands, results) - }]; -} - #endif // ENZYME_OPS diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt b/enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt new file mode 100644 index 000000000000..728df4d51c01 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt @@ -0,0 +1,30 @@ +add_mlir_dialect(ImpulseOps impulse) +add_mlir_doc(ImpulseDialect -gen-dialect-doc ImpulseDialect Enzyme/) +add_mlir_doc(ImpulseOps -gen-op-doc ImpulseOps Enzyme/) + +set(LLVM_TARGET_DEFINITIONS ImpulseOps.td) +mlir_tablegen(ImpulseAttributeInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(ImpulseAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=impulse) +mlir_tablegen(ImpulseAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=impulse) +add_public_tablegen_target(MLIRImpulseAttributesIncGen) + +set(LLVM_TARGET_DEFINITIONS ImpulseEnums.td) +mlir_tablegen(ImpulseEnums.h.inc -gen-enum-decls) +mlir_tablegen(ImpulseEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRImpulseEnumsIncGen) + +add_mlir_dialect_library(MLIRImpulse + ImpulseDialect.cpp + ImpulseOps.cpp + + ADDITIONAL_HEADER_DIRS + + DEPENDS + MLIRImpulseOpsIncGen + MLIRImpulseEnumsIncGen + MLIRImpulseAttributesIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIREnzyme +) diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/Dialect.td b/enzyme/Enzyme/MLIR/Dialect/Impulse/Dialect.td new file mode 100644 index 000000000000..5c3dc8182012 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/Dialect.td @@ -0,0 +1,32 @@ +#ifndef IMPULSE_DIALECT_TD +#define IMPULSE_DIALECT_TD + +include "mlir/IR/DialectBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/Traits.td" +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Impulse dialect definition. 
+//===----------------------------------------------------------------------===// + +def Impulse_Dialect : Dialect { + let name = "impulse"; + let summary = "A probabilistic programming dialect"; + let cppNamespace = "::mlir::impulse"; + let dependentDialects = [ + "mlir::enzyme::EnzymeDialect", + ]; + let useDefaultAttributePrinterParser = 1; +} + +//===----------------------------------------------------------------------===// +// Base Impulse operation definition. +//===----------------------------------------------------------------------===// + +class Impulse_Op traits = []> + : Op; + +class Impulse_Type : TypeDef; + +#endif // IMPULSE_DIALECT_TD diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h b/enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h new file mode 100644 index 000000000000..65a1309e55d7 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h @@ -0,0 +1,31 @@ +//===- Impulse.h - Impulse dialect --------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_IMPULSE_H +#define ENZYME_IMPULSE_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Bytecode/BytecodeOpInterface.h" + +#include "Dialect/Impulse/ImpulseEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "Dialect/Impulse/ImpulseAttributes.h.inc" + +#include "Dialect/Impulse/ImpulseOpsDialect.h.inc" + +#define GET_OP_CLASSES +#include "Dialect/Impulse/ImpulseOps.h.inc" + +#endif // ENZYME_IMPULSE_H diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp new file mode 100644 index 000000000000..cdb8131e8da2 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp @@ -0,0 +1,25 @@ +#include "Dialect/Impulse/Impulse.h" +#include "Dialect/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Builders.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "Dialect/Impulse/ImpulseEnums.cpp.inc" +#include "Dialect/Impulse/ImpulseOpsDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "Dialect/Impulse/ImpulseOps.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "Dialect/Impulse/ImpulseAttributes.cpp.inc" + +void mlir::impulse::ImpulseDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "Dialect/Impulse/ImpulseOps.cpp.inc" + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/Impulse/ImpulseAttributes.cpp.inc" + >(); +} diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td new file mode 100644 index 000000000000..565ef66c8b4f --- /dev/null +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td @@ -0,0 +1,110 @@ 
+//===- ImpulseEnums.td - Impulse dialect enums ---------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef IMPULSE_ENUMS +#define IMPULSE_ENUMS + +include "mlir/IR/EnumAttr.td" +include "Impulse/Dialect.td" + +class Impulse_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +def RngDistribution : I32EnumAttr<"RngDistribution", + "Random number distribution type", + [ + I32EnumAttrCase<"UNIFORM", 0>, + I32EnumAttrCase<"NORMAL", 1>, + I32EnumAttrCase<"MULTINORMAL", 2>, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::impulse"; +} + +def RngDistributionAttr : EnumAttr; + +def Transpose : I32EnumAttr<"Transpose", + "Transpose mode for matrix operations", + [ + I32EnumAttrCase<"NO_TRANSPOSE", 0>, + I32EnumAttrCase<"TRANSPOSE", 1>, + I32EnumAttrCase<"ADJOINT", 2>, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::impulse"; +} + +def TransposeAttr : EnumAttr; + +def SupportKind : I32EnumAttr<"SupportKind", + "Domain where distribution density is non-zero", + [ + I32EnumAttrCase<"REAL", 0>, + I32EnumAttrCase<"POSITIVE", 1>, + I32EnumAttrCase<"UNIT_INTERVAL", 2>, + I32EnumAttrCase<"INTERVAL", 3>, + I32EnumAttrCase<"GREATER_THAN", 4>, + I32EnumAttrCase<"LESS_THAN", 5>, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::impulse"; +} + +def SupportKindAttr : EnumAttr; + +def SupportAttr : Impulse_Attr<"Support", "support"> { + let summary = "Distribution support specification for constraint transforms"; + let description = [{ + Distribution support specification for constraint transforms. 
+ }]; + let parameters = (ins + "SupportKind":$kind, + OptionalParameter<"FloatAttr">:$lower_bound, + OptionalParameter<"FloatAttr">:$upper_bound + ); + let assemblyFormat = "`<` $kind (`,` `lower` `=` $lower_bound^)? (`,` `upper` `=` $upper_bound^)? `>`"; +} + +def HMCConfigAttr : Impulse_Attr<"HMCConfig", "hmc_config"> { + let summary = "Configuration for HMC inference"; + let description = [{ + HMC-specific parameters for fixed-trajectory Hamiltonian Monte Carlo. + - trajectory_length: Total trajectory length for leapfrog integration. + - adapt_step_size (default: true): Whether to adapt step size during warmup using dual averaging. + - adapt_mass_matrix (default: true): Whether to adapt the mass matrix during warmup using Welford covariance. + }]; + let parameters = (ins + "FloatAttr":$trajectory_length, + DefaultValuedParameter<"bool", "true">:$adapt_step_size, + DefaultValuedParameter<"bool", "true">:$adapt_mass_matrix + ); + let assemblyFormat = "`<` `trajectory_length` `=` $trajectory_length (`,` `adapt_step_size` `=` $adapt_step_size^)? (`,` `adapt_mass_matrix` `=` $adapt_mass_matrix^)? `>`"; +} + +def NUTSConfigAttr : Impulse_Attr<"NUTSConfig", "nuts_config"> { + let summary = "Configuration for NUTS inference"; + let description = [{ + NUTS-specific parameters for No-U-Turn Sampler with adaptive path length. + - max_tree_depth: Maximum tree depth, controls max leapfrog steps (2^max_tree_depth). + - max_delta_energy: Threshold for trajectory divergence (rejected if energy change exceeds this). + Default: 1000.0 + - adapt_step_size (default: true): Whether to adapt step size during warmup using dual averaging. + - adapt_mass_matrix (default: true): Whether to adapt the mass matrix during warmup using Welford covariance. 
+ }]; + let parameters = (ins + DefaultValuedParameter<"int64_t", "10">:$max_tree_depth, + OptionalParameter<"FloatAttr">:$max_delta_energy, + DefaultValuedParameter<"bool", "true">:$adapt_step_size, + DefaultValuedParameter<"bool", "true">:$adapt_mass_matrix + ); + let assemblyFormat = "`<` (`max_tree_depth` `=` $max_tree_depth^)? (`,` `max_delta_energy` `=` $max_delta_energy^)? (`,` `adapt_step_size` `=` $adapt_step_size^)? (`,` `adapt_mass_matrix` `=` $adapt_mass_matrix^)? `>`"; +} + +#endif // IMPULSE_ENUMS diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp new file mode 100644 index 000000000000..2ea8718344a2 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp @@ -0,0 +1,143 @@ +//===- ImpulseOps.cpp - Impulse dialect ops ----------------------*- C++ -*-===// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Dialect/Impulse/Impulse.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" + +using namespace mlir; +using namespace mlir::impulse; + +//===----------------------------------------------------------------------===// +// SampleOp +//===----------------------------------------------------------------------===// + +LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + if (getLogpdfAttr()) { + auto global = symbolTable.lookupNearestSymbolFrom( + *this, getLogpdfAttr()); + if (!global) + return emitOpError("'") + << getLogpdf().value() << "' does not reference a valid global " + << "funcOp"; + } + + return 
success(); +} + +//===----------------------------------------------------------------------===// +// GenerateOp +//===----------------------------------------------------------------------===// + +LogicalResult GenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SimulateOp +//===----------------------------------------------------------------------===// + +LogicalResult SimulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// RegenerateOp +//===----------------------------------------------------------------------===// + +LogicalResult +RegenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// MHOp +//===----------------------------------------------------------------------===// + +LogicalResult MHOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// InferOp 
+//===----------------------------------------------------------------------===// + +LogicalResult InferOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + if (auto fnAttr = getFnAttr()) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!global) + return emitOpError("'") + << getFn().value() << "' does not reference a valid global funcOp"; + } + + if (auto logpdfAttr = getLogpdfFnAttr()) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, logpdfAttr); + if (!global) + return emitOpError("'") << logpdfAttr.getValue() + << "' does not reference a valid global funcOp"; + } + + return success(); +} + +LogicalResult InferOp::verify() { + bool hasHMC = getHmcConfig().has_value(); + bool hasNUTS = getNutsConfig().has_value(); + + if (hasHMC + hasNUTS != 1) { + return emitOpError( + "Exactly one of hmc_config or nuts_config must be specified"); + } + + if (!getFnAttr() && !getLogpdfFnAttr()) { + return emitOpError("one of `fn` or `logpdf_fn` must be specified"); + } + + if (getFnAttr() && getLogpdfFnAttr()) { + return emitOpError("specifying both `fn` and `logpdf_fn` is unsupported"); + } + + if (getLogpdfFnAttr() && !getInitialPosition()) { + return emitOpError( + "custom logpdf mode requires `initial_position` to be provided"); + } + + return success(); +} diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td new file mode 100644 index 000000000000..9d9944f4124e --- /dev/null +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td @@ -0,0 +1,655 @@ +//===- ImpulseOps.td - Impulse dialect ops -----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef IMPULSE_OPS +#define IMPULSE_OPS + +include "Impulse/ImpulseEnums.td" +include "Impulse/Dialect.td" + +include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" + +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// Attributes +//===----------------------------------------------------------------------===// + +def SymbolAttr : Impulse_Attr<"Symbol", "symbol"> { + let summary = "Symbol associated with a Sample op"; + let description = [{ + Symbol associated with a Sample op. + }]; + let parameters = (ins "uint64_t":$ptr); + let assemblyFormat = "`<` $ptr `>`"; +} + +def AddressAttr : TypedArrayAttrBase; +def AddressArrayAttr : TypedArrayAttrBase; + +//===----------------------------------------------------------------------===// +// Probabilistic programming ops +//===----------------------------------------------------------------------===// + +def SampleOp : Impulse_Op<"sample", + [DeclareOpInterfaceMethods]> { + let summary = "Sample from a distribution"; + let description = [{ + Sample from a distribution. By convention, the 0th operand in `inputs` + or `outputs` is the initial RNG state (seed). 
+ }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + Variadic:$inputs, + OptionalAttr:$logpdf, + OptionalAttr:$symbol, + OptionalAttr:$support, + DefaultValuedStrAttr:$name + ); + + let results = (outs Variadic:$outputs); + + let assemblyFormat = [{ + $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + +def UntracedCallOp : Impulse_Op<"untracedCall"> { + let summary = "Call a probabilistic function without tracing"; + let description = [{ + Call a probabilistic function without tracing. By convention, the 0th operand in `inputs` + or `outputs` is the initial RNG state (seed). + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + Variadic:$inputs, + DefaultValuedStrAttr:$name + ); + + let results = (outs Variadic:$outputs); + + let assemblyFormat = [{ + $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + +def SimulateOp : Impulse_Op<"simulate", [DeclareOpInterfaceMethods]> { + let summary = "Simulate a generative function"; + let description = [{ + Simulates a generative function, building a trace tensor containing all + sampled values and computing the accumulated log probability weight. + + The `selection` attribute specifies all sample addresses in order, + determining the trace tensor layout. + + Returns: (trace, weight, rng, retvals...) 
+ - trace: tensor<1 x position_size x f64> - flattened samples + - weight: tensor - accumulated log probability + - rng: updated RNG state + - retvals: original function return values + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + Variadic:$inputs, + AddressArrayAttr:$selection, + DefaultValuedStrAttr:$name + ); + + let results = (outs + AnyRankedTensor:$trace, + AnyRankedTensor:$weight, + Variadic:$outputs + ); + + let assemblyFormat = [{ + $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + +def GenerateOp : Impulse_Op<"generate", [DeclareOpInterfaceMethods]> { + let summary = "Constrained generation from a generative function"; + let description = [{ + Generates from a generative function with some addresses constrained. + The constraint tensor contains flattened constrained values in the order + specified by constrained_addresses. + + Returns: (trace, weight, rng, retvals...) + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + Variadic:$inputs, + AnyRankedTensor:$constraint, + AddressArrayAttr:$selection, + AddressArrayAttr:$constrained_addresses, + DefaultValuedStrAttr:$name + ); + + let results = (outs + AnyRankedTensor:$trace, + AnyRankedTensor:$weight, + Variadic:$outputs + ); + + let assemblyFormat = [{ + $fn `(` $inputs `)` `given` $constraint attr-dict `:` functional-type(operands, results) + }]; +} + +def RegenerateOp : Impulse_Op<"regenerate", [DeclareOpInterfaceMethods]> { + let summary = "Regenerate selected addresses in a trace"; + let description = [{ + Regenerates selected addresses while keeping others fixed. + Used internally by MH. + + Takes explicit old_trace and returns new trace with weight. + + Returns: (new_trace, weight, retvals...) 
+ - new_trace: tensor<1 x position_size x f64> - flattened samples + - weight: tensor - accumulated log probability + - retvals: original function return values + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + Variadic:$inputs, + AnyRankedTensor:$original_trace, + AddressArrayAttr:$selection, + AddressArrayAttr:$regenerate_addresses, + DefaultValuedStrAttr:$name + ); + + let results = (outs + AnyRankedTensor:$new_trace, + AnyRankedTensor:$weight, + Variadic:$outputs + ); + + let assemblyFormat = [{ + $fn `(` $inputs `)` `given` $original_trace attr-dict `:` functional-type(operands, results) + }]; +} + +def MHOp : Impulse_Op<"mh", [DeclareOpInterfaceMethods]> { + let summary = "Metropolis-Hastings step for probabilistic inference"; + let description = [{ + Performs one MH step: regenerates selected addresses and accepts/rejects + based on weight ratio. + }]; + + let arguments = (ins + FlatSymbolRefAttr:$fn, + AnyRankedTensor:$original_trace, + AnyRankedTensor:$original_weight, + Variadic:$inputs, + AddressArrayAttr:$selection, + AddressArrayAttr:$regenerate_addresses, + DefaultValuedStrAttr:$name + ); + + let results = (outs + AnyRankedTensor:$new_trace, + AnyRankedTensor:$new_weight, + AnyRankedTensor:$accepted, + AnyType:$output_rng + ); + + let assemblyFormat = [{ + $fn `(` $inputs `)` `given` $original_trace `weight` $original_weight attr-dict `:` functional-type(operands, results) + }]; +} + +def InferOp : Impulse_Op<"infer", [DeclareOpInterfaceMethods, AttrSizedOperandSegments]> { + let summary = "MCMC inference for probabilistic programs"; + let description = [{ + Runs MCMC inference on selected addresses. + + Two modes of operation: + 1. Trace-based mode: `fn` and `original_trace` are provided. The model + function with `impulse.sample` ops defines the density. + 2. Custom logpdf mode: `logpdf_fn` and `initial_position` are provided. + The logpdf function maps position -> scalar log-density directly. 
+ + The `selection` attribute determines which addresses to sample via HMC/NUTS. + All sample addresses are included in the trace tensor for consistency. + + Returns: (trace, diagnostics, rng, final_position, final_gradient, + final_potential_energy, final_step_size, final_inverse_mass_matrix) + - trace: tensor + - diagnostics: tensor - placeholder for future expansion + - rng: updated RNG state + - final_position: tensor<1 x position_size x f64> - position after last sample + - final_gradient: tensor<1 x position_size x f64> - gradient at final position + - final_potential_energy: tensor - potential energy at final position + - final_step_size: tensor - adapted step size (after warmup) + - final_inverse_mass_matrix: tensor - adapted inverse mass matrix (after warmup) + }]; + + let arguments = (ins + OptionalAttr:$fn, + Variadic:$inputs, + Optional:$original_trace, + AddressArrayAttr:$selection, + AddressArrayAttr:$all_addresses, + + DefaultValuedAttr:$num_warmup, + DefaultValuedAttr:$num_samples, + DefaultValuedAttr:$thinning, + + // Algorithm-specific MCMC parameters + Optional:$inverse_mass_matrix, + Optional:$step_size, + + // Algorithm-specific configurations + OptionalAttr:$hmc_config, + OptionalAttr:$nuts_config, + + // Custom logpdf mode + OptionalAttr:$logpdf_fn, + Optional:$initial_position, + + Optional:$initial_gradient, + Optional:$initial_potential_energy, + + OptionalAttr:$autodiff_attrs, + DefaultValuedStrAttr:$name + ); + + let results = (outs + AnyRankedTensor:$trace, + AnyRankedTensor:$diagnostics, + AnyType:$output_rng_state, + AnyRankedTensor:$final_position, + AnyRankedTensor:$final_gradient, + AnyRankedTensor:$final_potential_energy, + AnyRankedTensor:$final_step_size, + AnyRankedTensor:$final_inverse_mass_matrix + ); + + let assemblyFormat = [{ + ($fn^)? + `(` $inputs `)` + (`given` $original_trace^)? + (`inverse_mass_matrix` `=` $inverse_mass_matrix^)? + (`step_size` `=` $step_size^)? + (`logpdf_fn` `=` $logpdf_fn^)? 
+ (`initial_position` `=` $initial_position^)? + (`initial_gradient` `=` $initial_gradient^)? + (`initial_potential_energy` `=` $initial_potential_energy^)? + attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// RNG ops +//===----------------------------------------------------------------------===// + +def RandomOp : Impulse_Op<"random"> { + let summary = "Generate random numbers using specified distribution"; + let description = [{ + Generates random numbers using the rng_distribution algorithm and produces + a result tensor. + + If rng_distribution = UNIFORM, then the random numbers are generated following + the uniform distribution over the interval [a, b). If a >= b, the behavior is + undefined. + + If rng_distribution = NORMAL, then the random numbers are generated following + the normal distribution with mean = a and standard deviation = b. If b < 0, + the behavior is undefined. + + If rng_distribution = MULTINORMAL, then the random numbers are generated + following the multivariate normal distribution with mean = a (scalar or vector) + and covariance matrix = b. The parameter b should be a positive definite matrix. + + By convention, the 0th operand in inputs is the initial RNG state and the + 0th operand in results is the updated RNG state. + }]; + + let arguments = (ins + AnyType:$rng_state, + AnyType:$a, + AnyType:$b, + RngDistributionAttr:$rng_distribution + ); + + let results = (outs AnyType:$output_rng_state, AnyType:$result); + + let assemblyFormat = [{ + $rng_state `,` $a `,` $b attr-dict `:` functional-type(operands, results) + }]; +} + +def RandomSplitOp : Impulse_Op<"randomSplit"> { + let summary = "Split RNG state into multiple independent states"; + let description = [{ + Splits an RNG state into multiple independent RNG states. 
+ Reference: https://github.com/jax-ml/jax/blob/c25e095fcec9678a4ce5f723afce0c6a3c48a5e7/jax/_src/random.py#L281-L294 + }]; + + let arguments = (ins AnyType:$rng_state); + let results = (outs Variadic:$output_rng_states); + + let assemblyFormat = [{ + $rng_state attr-dict `:` functional-type(operands, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Control flow ops +//===----------------------------------------------------------------------===// + +def ForOp : Impulse_Op<"for", [AutomaticAllocationScope]> { + let summary = "Counted loop for probabilistic programming"; + let description = [{ + A counted loop operation that iterates from `lowerBound` to `upperBound` + by `step`, carrying `iter_args` through each iteration. The iteration + variable and iter_args are passed to the body region. + }]; + + let arguments = (ins + AnyType:$lowerBound, + AnyType:$upperBound, + AnyType:$step, + Variadic:$initArgs + ); + + let regions = (region SizedRegion<1>:$region); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + `(` $lowerBound `:` type($lowerBound) `)` + `to` `(` $upperBound `:` type($upperBound) `)` + `step` `(` $step `:` type($step) `)` + (`iter_args` `(` $initArgs^ `:` type($initArgs) `)`)? + `->` type(results) $region attr-dict + }]; +} + +def WhileOp : Impulse_Op<"while", [AutomaticAllocationScope]> { + let summary = "While loop with condition"; + let description = [{ + A while loop operation that continues iterating as long as the condition + evaluates to true. Intended to be lowered to `stablehlo.while`. 
+ }]; + + let arguments = (ins Variadic:$initArgs); + let regions = (region SizedRegion<1>:$conditionRegion, + SizedRegion<1>:$bodyRegion); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + `(` $initArgs `:` type($initArgs) `)` + `->` type(results) + `condition` $conditionRegion + `body` $bodyRegion + attr-dict + }]; +} + +def IfOp : Impulse_Op<"if", [AutomaticAllocationScope, RecursiveMemoryEffects]> { + let summary = "Conditional execution with two branches"; + let description = [{ + A conditional operation that executes exactly one of two branches based on a + boolean predicate. + }]; + + let arguments = (ins AnyType:$predicate); + let regions = (region SizedRegion<1>:$trueBranch, + SizedRegion<1>:$falseBranch); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + `(` $predicate `)` `(` $trueBranch `,` $falseBranch `)` + attr-dict `:` functional-type($predicate, results) + }]; +} + +def YieldOp : Impulse_Op<"yield", [Pure, ReturnLike, Terminator, + ParentOneOf<["ForOp", "WhileOp", "IfOp"]>]> { + let summary = "Yield values at the end of a loop or if op"; + let arguments = (ins Variadic:$operands); + let assemblyFormat = [{ + attr-dict ($operands^ `:` type($operands))? + }]; +} + +//===----------------------------------------------------------------------===// +// Linear algebra ops +//===----------------------------------------------------------------------===// + +def DotOp : Impulse_Op<"dot", [Pure]> { + let summary = "Computes a general dot product operation"; + let description = [{ + Computes a general dot product operation. To be lowered to `stablehlo.dot_general`. 
+ }]; + + let arguments = (ins + AnyRankedTensor:$lhs, + AnyRankedTensor:$rhs, + DenseI64ArrayAttr:$lhs_batching_dimensions, + DenseI64ArrayAttr:$rhs_batching_dimensions, + DenseI64ArrayAttr:$lhs_contracting_dimensions, + DenseI64ArrayAttr:$rhs_contracting_dimensions + ); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` functional-type(operands, results) + }]; +} + +def CholeskyOp : Impulse_Op<"cholesky", [Pure]> { + let summary = "Compute Cholesky decomposition of a symmetric positive definite matrix"; + let description = [{ + Computes the Cholesky decomposition of a symmetric positive definite matrix A. + Returns L such that A = L @ L^T (if lower=true) or A = U^T @ U (if lower=false). + }]; + + let arguments = (ins + AnyRankedTensor:$input, + DefaultValuedAttr:$lower + ); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $input attr-dict `:` functional-type(operands, results) + }]; +} + +def TriangularSolveOp : Impulse_Op<"triangular_solve", [Pure]> { + let summary = "Solve a triangular linear system"; + let description = [{ + Solves a system of linear equations with a triangular coefficient matrix. + If left_side=true, solves op(A) @ X = B for X. + If left_side=false, solves X @ op(A) = B for X. + op(A) is determined by transpose_a: NO_TRANSPOSE, TRANSPOSE, or ADJOINT. 
+ }]; + + let arguments = (ins + AnyRankedTensor:$a, + AnyRankedTensor:$b, + DefaultValuedAttr:$left_side, + DefaultValuedAttr:$lower, + DefaultValuedAttr:$unit_diagonal, + DefaultValuedAttr:$transpose_a + ); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $a `,` $b attr-dict `:` functional-type(operands, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Tensor manipulation ops +//===----------------------------------------------------------------------===// + +def SelectOp : Impulse_Op<"select", [Pure]> { + let summary = "Extended select operation"; + let description = [{ + Extended select operation that supports: + - `tensor` conditions with differently-sized operands + - standard cases supported by `arith.select` + }]; + + let arguments = (ins + AnyRankedTensor:$condition, + AnyType:$true_value, + AnyType:$false_value + ); + + let results = (outs AnyType:$result); + + let assemblyFormat = [{ + $condition `,` $true_value `,` $false_value attr-dict `:` functional-type(operands, results) + }]; +} + +def ReshapeOp : Impulse_Op<"reshape", [Pure]> { + let summary = "Reshape a tensor to a new static shape"; + let arguments = (ins AnyRankedTensor:$input); + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $input attr-dict `:` functional-type($input, $result) + }]; +} + +def SliceOp : Impulse_Op<"slice", [Pure]> { + let summary = "Extract a static slice from a tensor"; + let description = [{ + Extract a static slice from a tensor.
+ }]; + + let arguments = (ins + AnyRankedTensor:$operand, + DenseI64ArrayAttr:$start_indices, + DenseI64ArrayAttr:$limit_indices, + DenseI64ArrayAttr:$strides + ); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $operand attr-dict `:` functional-type($operand, $result) + }]; +} + +def DynamicSliceOp : Impulse_Op<"dynamic_slice", [Pure]> { + let summary = "Extract a slice from a tensor at dynamic start indices"; + let description = [{ + Extract a slice from a tensor at dynamic start indices. + }]; + + let arguments = (ins + AnyRankedTensor:$operand, + Variadic:$start_indices, + DenseI64ArrayAttr:$slice_sizes + ); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $operand `,` $start_indices attr-dict `:` functional-type(operands, results) + }]; +} + +def DynamicUpdateSliceOp : Impulse_Op<"dynamic_update_slice", [Pure]> { + let summary = "Update a slice in a tensor at dynamic start indices"; + let description = [{ + Update a slice in a tensor at dynamic start indices. + }]; + + let arguments = (ins + AnyRankedTensor:$operand, + AnyRankedTensor:$update, + Variadic:$start_indices + ); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $operand `,` $update `,` $start_indices attr-dict `:` functional-type(operands, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Math ops +//===----------------------------------------------------------------------===// + +def LogAddExpOp : Impulse_Op<"log_add_exp", [Pure]> { + let summary = "Computes log(exp(x) + exp(y))"; + let description = [{ + Computes log(exp(x) + exp(y)). 
+ }]; + + let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` functional-type(operands, results) + }]; +} + +def LogisticOp : Impulse_Op<"logistic", [Pure]> { + let summary = "Computes logistic (sigmoid) function: 1 / (1 + exp(-x))"; + let description = [{ + Computes the logistic (sigmoid) function: 1 / (1 + exp(-x)). + }]; + + let arguments = (ins AnyRankedTensor:$operand); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $operand attr-dict `:` functional-type(operands, results) + }]; +} + +def PopcountOp : Impulse_Op<"popcount", [Pure, SameOperandsAndResultElementType, SameOperandsAndResultShape]> { + let summary = "Computes population count"; + let description = [{ + Returns the number of 1-bits elementwise. + }]; + + let arguments = (ins AnyRankedTensor:$operand); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $operand attr-dict `:` functional-type(operands, results) + }]; +} + +//===----------------------------------------------------------------------===// +// Debug ops +//===----------------------------------------------------------------------===// + +def DumpOp : Impulse_Op<"dump"> { + let summary = "Debug operation to dump a tensor value at runtime"; + let description = [{ + Debug operation that dumps a tensor value with a label. 
+ }]; + + let arguments = (ins + AnyType:$value, + StrAttr:$label + ); + + let results = (outs AnyType:$output); + + let assemblyFormat = [{ + $value attr-dict `:` functional-type($value, results) + }]; +} + +#endif // IMPULSE_OPS diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index dd6b35d76ae7..fcf10e5c5c1a 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -1035,142 +1035,3 @@ void ForwardDiffRegionOp::getCanonicalizationPatterns( patterns.add, FwdInpOpt, RemoveUnusedArgs>(context); } -//===----------------------------------------------------------------------===// -// SampleOp -//===----------------------------------------------------------------------===// - -LogicalResult SampleOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // TODO: Verify that the result type is same as the type of the referenced - // func.func op. - auto global = - symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); - if (!global) - return emitOpError("'") - << getFn() << "' does not reference a valid global funcOp"; - - if (getLogpdfAttr()) { - auto global = symbolTable.lookupNearestSymbolFrom( - *this, getLogpdfAttr()); - if (!global) - return emitOpError("'") - << getLogpdf().value() << "' does not reference a valid global " - << "funcOp"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// GenerateOp -//===----------------------------------------------------------------------===// - -LogicalResult GenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // TODO: Verify that the result type is same as the type of the referenced - // func.func op. 
- auto global = - symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); - if (!global) - return emitOpError("'") - << getFn() << "' does not reference a valid global funcOp"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// SimulateOp -//===----------------------------------------------------------------------===// - -LogicalResult SimulateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // TODO: Verify that the result type is same as the type of the referenced - // func.func op. - auto global = - symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); - if (!global) - return emitOpError("'") - << getFn() << "' does not reference a valid global funcOp"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// RegenerateOp -//===----------------------------------------------------------------------===// - -LogicalResult -RegenerateOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // TODO: Verify that the result type is same as the type of the referenced - // func.func op. - auto global = - symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); - if (!global) - return emitOpError("'") - << getFn() << "' does not reference a valid global funcOp"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// MHOp -//===----------------------------------------------------------------------===// - -LogicalResult MHOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // TODO: Verify that the result type is same as the type of the referenced - // func.func op. 
- auto global = - symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); - if (!global) - return emitOpError("'") - << getFn() << "' does not reference a valid global funcOp"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// MCMCOp -//===----------------------------------------------------------------------===// - -LogicalResult MCMCOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - if (auto fnAttr = getFnAttr()) { - auto global = - symbolTable.lookupNearestSymbolFrom(*this, fnAttr); - if (!global) - return emitOpError("'") - << getFn().value() << "' does not reference a valid global funcOp"; - } - - if (auto logpdfAttr = getLogpdfFnAttr()) { - auto global = - symbolTable.lookupNearestSymbolFrom(*this, logpdfAttr); - if (!global) - return emitOpError("'") << logpdfAttr.getValue() - << "' does not reference a valid global funcOp"; - } - - return success(); -} - -LogicalResult MCMCOp::verify() { - bool hasHMC = getHmcConfig().has_value(); - bool hasNUTS = getNutsConfig().has_value(); - - if (hasHMC + hasNUTS != 1) { - return emitOpError( - "Exactly one of hmc_config or nuts_config must be specified"); - } - - // TODO(#2695): More verification - if (!getFnAttr() && !getLogpdfFnAttr()) { - return emitOpError("one of `fn` or `logpdf_fn` must be specified"); - } - - if (getFnAttr() && getLogpdfFnAttr()) { - return emitOpError("specifying both `fn` and `logpdf_fn` is unsupported"); - } - - if (getLogpdfFnAttr() && !getInitialPosition()) { - return emitOpError( - "custom logpdf mode requires `initial_position` to be provided"); - } - - return success(); -} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index f88d212e8c76..d39fd3da387b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ 
b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -481,4 +481,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerFuncDialectAutoDiffInterface(registry); enzyme::registerTensorDialectAutoDiffInterface(registry); enzyme::registerEnzymeDialectAutoDiffInterface(registry); + enzyme::registerImpulseDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index dabce9ecd520..f01b3369a234 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -263,6 +263,7 @@ void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry); void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry); void registerEnzymeDialectAutoDiffInterface(DialectRegistry ®istry); +void registerImpulseDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td index 32098bce5aab..897e58ff4dfc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td @@ -1,4 +1,3 @@ include "Common.td" def : InactiveOp<"enzyme", "IgnoreDerivativesOp">; -def : InactiveOp<"enzyme", "DumpOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..a68f6386c6ff --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,28 @@ +//===- ImpulseAutoDiffOpInterfaceImpl.cpp -------------------------*- C++ -*-===// 
+// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" + +#include "Dialect/Impulse/Impulse.h" +#include "mlir/IR/TypeSupport.h" + +using namespace mlir; +using namespace mlir::enzyme; +using namespace mlir::impulse; + +namespace { +#include "Implementations/ImpulseDerivatives.inc" +} // namespace + +void mlir::enzyme::registerImpulseDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension( + +[](MLIRContext *context, impulse::ImpulseDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td new file mode 100644 index 000000000000..1e04268b9cd7 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td @@ -0,0 +1,3 @@ +include "Common.td" + +def : InactiveOp<"impulse", "DumpOp">; diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp index 4a5ba92e5c18..af271c7c3387 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp @@ -3,23 +3,24 @@ #include "mlir/CAPI/IR.h" #include "Dialect/Dialect.h" +#include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" MlirAttribute enzymeRngDistributionAttrGet(MlirContext ctx, EnzymeRngDistribution dist) { - mlir::enzyme::RngDistribution rngDist; + mlir::impulse::RngDistribution rngDist; switch (dist) { case EnzymeRngDistribution_Uniform: - rngDist = mlir::enzyme::RngDistribution::UNIFORM; + rngDist = mlir::impulse::RngDistribution::UNIFORM; break; case EnzymeRngDistribution_Normal: - rngDist = mlir::enzyme::RngDistribution::NORMAL; + rngDist = 
mlir::impulse::RngDistribution::NORMAL; break; case EnzymeRngDistribution_MultiNormal: - rngDist = mlir::enzyme::RngDistribution::MULTINORMAL; + rngDist = mlir::impulse::RngDistribution::MULTINORMAL; break; } - return wrap(mlir::enzyme::RngDistributionAttr::get(unwrap(ctx), rngDist)); + return wrap(mlir::impulse::RngDistributionAttr::get(unwrap(ctx), rngDist)); } MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, @@ -27,25 +28,25 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, bool hasUpperBound, double upperBound) { auto *mlirCtx = unwrap(ctx); - mlir::enzyme::SupportKind supportKind; + mlir::impulse::SupportKind supportKind; switch (kind) { case EnzymeSupportKind_Real: - supportKind = mlir::enzyme::SupportKind::REAL; + supportKind = mlir::impulse::SupportKind::REAL; break; case EnzymeSupportKind_Positive: - supportKind = mlir::enzyme::SupportKind::POSITIVE; + supportKind = mlir::impulse::SupportKind::POSITIVE; break; case EnzymeSupportKind_UnitInterval: - supportKind = mlir::enzyme::SupportKind::UNIT_INTERVAL; + supportKind = mlir::impulse::SupportKind::UNIT_INTERVAL; break; case EnzymeSupportKind_Interval: - supportKind = mlir::enzyme::SupportKind::INTERVAL; + supportKind = mlir::impulse::SupportKind::INTERVAL; break; case EnzymeSupportKind_GreaterThan: - supportKind = mlir::enzyme::SupportKind::GREATER_THAN; + supportKind = mlir::impulse::SupportKind::GREATER_THAN; break; case EnzymeSupportKind_LessThan: - supportKind = mlir::enzyme::SupportKind::LESS_THAN; + supportKind = mlir::impulse::SupportKind::LESS_THAN; break; } @@ -59,7 +60,7 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, upperAttr = mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), upperBound); - return wrap(mlir::enzyme::SupportAttr::get(mlirCtx, supportKind, lowerAttr, + return wrap(mlir::impulse::SupportAttr::get(mlirCtx, supportKind, lowerAttr, upperAttr)); } @@ -69,7 +70,7 @@ MlirAttribute 
enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, auto trajectoryLengthAttr = mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength); - return wrap(mlir::enzyme::HMCConfigAttr::get(mlirCtx, trajectoryLengthAttr, + return wrap(mlir::impulse::HMCConfigAttr::get(mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix)); } @@ -84,11 +85,11 @@ MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, maxDeltaEnergyAttr = mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), maxDeltaEnergy); - return wrap(mlir::enzyme::NUTSConfigAttr::get( + return wrap(mlir::impulse::NUTSConfigAttr::get( mlirCtx, maxTreeDepth, maxDeltaEnergyAttr, adaptStepSize, adaptMassMatrix)); } MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t ptr) { - return wrap(mlir::enzyme::SymbolAttr::get(unwrap(ctx), ptr)); + return wrap(mlir::impulse::SymbolAttr::get(unwrap(ctx), ptr)); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt b/enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt index 19b5295af0d1..0f1d224efc1f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Interfaces/CMakeLists.txt @@ -22,8 +22,11 @@ add_mlir_library(MLIREnzymeAutoDiffInterface MLIRAutoDiffOpInterfaceIncGen MLIRAutoDiffTypeInterfaceIncGen MLIREnzymeEnumsIncGen + MLIRImpulseOpsIncGen + MLIRImpulseEnumsIncGen LINK_LIBS PUBLIC MLIRIR MLIREnzymeAnalysis + MLIRImpulse ) diff --git a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp index e58bdcbfabff..2f1803331a68 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp @@ -10,6 +10,7 @@ #include "HMCUtils.h" +#include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -20,7 +21,7 @@ using namespace mlir; using namespace mlir::enzyme; -using namespace mlir::enzyme::MCMC; +using namespace mlir::impulse; 
SmallVector NUTSTreeState::toValues() const { return {q_left, p_left, grad_left, @@ -60,11 +61,11 @@ SmallVector NUTSTreeState::getTypes() const { return types; } -Value MCMC::conditionalDump(OpBuilder &builder, Location loc, Value value, +Value impulse::conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump) { if (debugDump) { - return enzyme::DumpOp::create(builder, loc, value.getType(), value, - builder.getStringAttr(label)) + return impulse::DumpOp::create(builder, loc, value.getType(), value, + builder.getStringAttr(label)) .getOutput(); } return value; @@ -123,19 +124,19 @@ static Value reverseRowsAndColumns(OpBuilder &builder, Location loc, auto P = createPermutationMatrix(builder, loc, matrixType); // PA = P @ A - auto PA = enzyme::DotOp::create( + auto PA = impulse::DotOp::create( builder, loc, matrixType, P, matrix, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({0})); // PAP = PA @ P - return enzyme::DotOp::create( + return impulse::DotOp::create( builder, loc, matrixType, PA, P, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({0})); } -Value MCMC::applyInverseMassMatrix(OpBuilder &builder, Location loc, +Value impulse::applyInverseMassMatrix(OpBuilder &builder, Location loc, Value invMass, Value momentum, RankedTensorType positionType) { if (!invMass) { @@ -147,11 +148,11 @@ Value MCMC::applyInverseMassMatrix(OpBuilder &builder, Location loc, if (invMassType.getRank() == 1) { // Diagonal: element-wise auto diagMass = - enzyme::ReshapeOp::create(builder, loc, positionType, invMass); + impulse::ReshapeOp::create(builder, loc, positionType, invMass); return arith::MulFOp::create(builder, loc, diagMass, momentum); } else if (invMassType.getRank() == 2) { // Dense: v = invMass @ p - return enzyme::DotOp::create( + return impulse::DotOp::create( builder, 
loc, positionType, momentum, invMass, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({0})); @@ -162,7 +163,7 @@ Value MCMC::applyInverseMassMatrix(OpBuilder &builder, Location loc, return nullptr; } -Value MCMC::computeKineticEnergy(OpBuilder &builder, Location loc, +Value impulse::computeKineticEnergy(OpBuilder &builder, Location loc, Value momentum, Value invMass, RankedTensorType positionType) { auto elemType = positionType.getElementType(); @@ -178,7 +179,7 @@ Value MCMC::computeKineticEnergy(OpBuilder &builder, Location loc, // K = 0.5 * p^T @ v // For 2D tensors [1, N], contract over both dimensions to get scalar - auto pDotV = enzyme::DotOp::create( + auto pDotV = impulse::DotOp::create( builder, loc, scalarType, momentum, v, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0, 1}), builder.getDenseI64ArrayAttr({0, 1})); @@ -186,7 +187,7 @@ Value MCMC::computeKineticEnergy(OpBuilder &builder, Location loc, return arith::MulFOp::create(builder, loc, halfConst, pDotV); } -Value MCMC::computeMassMatrixSqrt(OpBuilder &builder, Location loc, +Value impulse::computeMassMatrixSqrt(OpBuilder &builder, Location loc, Value invMass, RankedTensorType positionType) { if (!invMass) { @@ -211,24 +212,24 @@ Value MCMC::computeMassMatrixSqrt(OpBuilder &builder, Location loc, // https://github.com/pyro-ppl/numpyro/blob/6a9cb9a530fe53897edb6c472368e58965b034e4/numpyro/infer/hmc_util.py#L499 auto reversedInvMass = reverseRowsAndColumns(builder, loc, invMass); auto L_reversed = - enzyme::CholeskyOp::create(builder, loc, invMassType, reversedInvMass, + impulse::CholeskyOp::create(builder, loc, invMassType, reversedInvMass, /*lower=*/builder.getBoolAttr(true)); auto massMatrixSqrtInvT = reverseRowsAndColumns(builder, loc, L_reversed); auto identityMatrix = createIdentityMatrix(builder, loc, invMassType); - auto massMatrixSqrt = 
enzyme::TriangularSolveOp::create( + auto massMatrixSqrt = impulse::TriangularSolveOp::create( builder, loc, invMassType, massMatrixSqrtInvT, identityMatrix, /*left_side=*/builder.getBoolAttr(true), /*lower=*/builder.getBoolAttr(false), /*unit_diagonal=*/builder.getBoolAttr(false), /*transpose_a=*/ - enzyme::TransposeAttr::get(builder.getContext(), - enzyme::Transpose::TRANSPOSE)); + impulse::TransposeAttr::get(builder.getContext(), + impulse::Transpose::TRANSPOSE)); return massMatrixSqrt; } } -std::pair MCMC::sampleMomentum(OpBuilder &builder, Location loc, +std::pair impulse::sampleMomentum(OpBuilder &builder, Location loc, Value rng, Value invMass, Value massMatrixSqrt, RankedTensorType positionType, @@ -247,11 +248,11 @@ std::pair MCMC::sampleMomentum(OpBuilder &builder, Location loc, debugDump); // Sample eps ~ N(0, I) - auto randomOp = enzyme::RandomOp::create( + auto randomOp = impulse::RandomOp::create( builder, loc, TypeRange{rng.getType(), positionType}, rng, zeroConst, oneConst, - enzyme::RngDistributionAttr::get(builder.getContext(), - enzyme::RngDistribution::NORMAL)); + impulse::RngDistributionAttr::get(builder.getContext(), + impulse::RngDistribution::NORMAL)); auto rngOut = randomOp.getOutputRngState(); auto eps = randomOp.getResult(); @@ -265,12 +266,12 @@ std::pair MCMC::sampleMomentum(OpBuilder &builder, Location loc, if (massMatrixSqrtType.getRank() == 1) { // Diagonal: p = massMatrixSqrt * eps (element-wise) auto diagSqrt = - enzyme::ReshapeOp::create(builder, loc, positionType, massMatrixSqrt); + impulse::ReshapeOp::create(builder, loc, positionType, massMatrixSqrt); auto p = arith::MulFOp::create(builder, loc, diagSqrt, eps); return {p, rngOut}; } else { // Dense: p = massMatrixSqrt @ eps - auto p = enzyme::DotOp::create( + auto p = impulse::DotOp::create( builder, loc, positionType, eps, massMatrixSqrt, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({1}), builder.getDenseI64ArrayAttr({1})); 
@@ -297,7 +298,7 @@ static Value scatterPositionToTrace(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(info.offset))); SmallVector extractIndices{c0, posOffset}; - auto slice = enzyme::DynamicSliceOp::create( + auto slice = impulse::DynamicSliceOp::create( builder, loc, sliceType, position2d, extractIndices, builder.getDenseI64ArrayAttr({1, info.size})); @@ -306,7 +307,7 @@ static Value scatterPositionToTrace(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(info.traceOffset))); SmallVector updateIndices{c0, traceOffset}; - result = enzyme::DynamicUpdateSliceOp::create(builder, loc, traceType, + result = impulse::DynamicUpdateSliceOp::create(builder, loc, traceType, result, slice, updateIndices); } return result; @@ -334,7 +335,7 @@ static Value gatherPositionFromTrace(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(info.traceOffset))); SmallVector extractIndices{c0, traceOffset}; - auto slice = enzyme::DynamicSliceOp::create( + auto slice = impulse::DynamicSliceOp::create( builder, loc, sliceType, fullTrace, extractIndices, builder.getDenseI64ArrayAttr({1, info.size})); @@ -343,13 +344,13 @@ static Value gatherPositionFromTrace(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(info.offset))); SmallVector updateIndices{c0, posOffset}; - result = enzyme::DynamicUpdateSliceOp::create(builder, loc, positionType2d, + result = impulse::DynamicUpdateSliceOp::create(builder, loc, positionType2d, result, slice, updateIndices); } return result; } -GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder, +GradientResult impulse::computePotentialAndGradient(OpBuilder &builder, Location loc, Value position, Value rng, const HMCContext &ctx) { @@ -368,7 +369,7 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder, auto autodiffGradType = positionType; if 
(isCustomLogpdf) { autodiffPosition = - enzyme::ReshapeOp::create(builder, loc, flatType, position); + impulse::ReshapeOp::create(builder, loc, flatType, position); autodiffPositionType = flatType; autodiffGradType = flatType; } @@ -428,7 +429,7 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder, generateResultTypes.append(ctx.fnResultTypes.begin(), ctx.fnResultTypes.end()); - auto generateOp = enzyme::GenerateOp::create( + auto generateOp = impulse::GenerateOp::create( builder, loc, generateResultTypes, ctx.fn, generateInputs, fullTrace, ctx.allAddresses, ctx.allAddresses, builder.getStringAttr("")); @@ -447,7 +448,7 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder, Value grad = autodiffOp.getResult(2); if (isCustomLogpdf) { - grad = enzyme::ReshapeOp::create(builder, loc, positionType, grad); + grad = impulse::ReshapeOp::create(builder, loc, positionType, grad); } return { @@ -457,7 +458,7 @@ GradientResult MCMC::computePotentialAndGradient(OpBuilder &builder, }; } -IntegrationResult MCMC::computeIntegrationStep(OpBuilder &builder, Location loc, +IntegrationResult impulse::computeIntegrationStep(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const HMCContext &ctx) { @@ -466,7 +467,7 @@ IntegrationResult MCMC::computeIntegrationStep(OpBuilder &builder, Location loc, auto elemType = ctx.getElementType(); auto negStepSize = arith::NegFOp::create(builder, loc, ctx.stepSize); - Value signedStepSize = enzyme::SelectOp::create( + Value signedStepSize = impulse::SelectOp::create( builder, loc, scalarType, direction, ctx.stepSize, negStepSize); auto halfConst = arith::ConstantOp::create( @@ -475,12 +476,12 @@ IntegrationResult MCMC::computeIntegrationStep(OpBuilder &builder, Location loc, ArrayRef shape = positionType.getShape(); auto stepSizeBroadcast = - enzyme::BroadcastOp::create(builder, loc, positionType, signedStepSize, + BroadcastOp::create(builder, loc, positionType, 
signedStepSize, builder.getDenseI64ArrayAttr(shape)); auto halfStep = arith::MulFOp::create(builder, loc, halfConst, signedStepSize); auto halfStepBroadcast = - enzyme::BroadcastOp::create(builder, loc, positionType, halfStep, + BroadcastOp::create(builder, loc, positionType, halfStep, builder.getDenseI64ArrayAttr(shape)); // 1. Half step momentum: p_half = p - 0.5 * eps * grad @@ -505,7 +506,7 @@ IntegrationResult MCMC::computeIntegrationStep(OpBuilder &builder, Location loc, return {qNew, pNew, gradResult.grad, gradResult.U, gradResult.rng}; } -Value MCMC::checkTurning(OpBuilder &builder, Location loc, Value pLeft, +Value impulse::checkTurning(OpBuilder &builder, Location loc, Value pLeft, Value pRight, Value pSum, const NUTSContext &ctx) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); @@ -524,7 +525,7 @@ Value MCMC::checkTurning(OpBuilder &builder, Location loc, Value pLeft, applyInverseMassMatrix(builder, loc, ctx.invMass, pRight, positionType); // p_sum_centered = p_sum - (p_left + p_right) / 2 - auto halfBroadcast = enzyme::BroadcastOp::create( + auto halfBroadcast = BroadcastOp::create( builder, loc, positionType, halfConst, builder.getDenseI64ArrayAttr(positionType.getShape())); @@ -533,12 +534,12 @@ Value MCMC::checkTurning(OpBuilder &builder, Location loc, Value pLeft, arith::MulFOp::create(builder, loc, halfBroadcast, pLeftPlusPRight); Value pSumCentered = arith::SubFOp::create(builder, loc, pSum, halfSum); - auto leftAngle = enzyme::DotOp::create( + auto leftAngle = impulse::DotOp::create( builder, loc, scalarType, vLeft, pSumCentered, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0, 1}), builder.getDenseI64ArrayAttr({0, 1})); - auto rightAngle = enzyme::DotOp::create( + auto rightAngle = impulse::DotOp::create( builder, loc, scalarType, vRight, pSumCentered, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0, 1}), @@ 
-553,15 +554,15 @@ Value MCMC::checkTurning(OpBuilder &builder, Location loc, Value pLeft, return arith::OrIOp::create(builder, loc, leftNeg, rightNeg); } -Value MCMC::computeUniformTransitionProb(OpBuilder &builder, Location loc, +Value impulse::computeUniformTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight) { Value weightDiff = arith::SubFOp::create(builder, loc, newWeight, currentWeight); - return enzyme::LogisticOp::create(builder, loc, weightDiff.getType(), + return impulse::LogisticOp::create(builder, loc, weightDiff.getType(), weightDiff); } -Value MCMC::computeBiasedTransitionProb(OpBuilder &builder, Location loc, +Value impulse::computeBiasedTransitionProb(OpBuilder &builder, Location loc, Value currentWeight, Value newWeight, Value turning, Value diverging) { auto resultType = cast(currentWeight.getType()); @@ -585,7 +586,7 @@ Value MCMC::computeBiasedTransitionProb(OpBuilder &builder, Location loc, zeroConst, clippedProb); } -NUTSTreeState MCMC::combineTrees(OpBuilder &builder, Location loc, +NUTSTreeState impulse::combineTrees(OpBuilder &builder, Location loc, const NUTSTreeState &tree, const NUTSTreeState &subTree, Value direction, Value rng, bool biased, @@ -602,21 +603,21 @@ NUTSTreeState MCMC::combineTrees(OpBuilder &builder, Location loc, builder, loc, scalarType, DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0))); - auto qLeft = enzyme::SelectOp::create(builder, loc, positionType, direction, + auto qLeft = impulse::SelectOp::create(builder, loc, positionType, direction, tree.q_left, subTree.q_left); - auto pLeft = enzyme::SelectOp::create(builder, loc, positionType, direction, + auto pLeft = impulse::SelectOp::create(builder, loc, positionType, direction, tree.p_left, subTree.p_left); - auto gradLeft = enzyme::SelectOp::create( + auto gradLeft = impulse::SelectOp::create( builder, loc, positionType, direction, tree.grad_left, subTree.grad_left); - auto qRight = 
enzyme::SelectOp::create(builder, loc, positionType, direction, + auto qRight = impulse::SelectOp::create(builder, loc, positionType, direction, subTree.q_right, tree.q_right); - auto pRight = enzyme::SelectOp::create(builder, loc, positionType, direction, + auto pRight = impulse::SelectOp::create(builder, loc, positionType, direction, subTree.p_right, tree.p_right); auto gradRight = - enzyme::SelectOp::create(builder, loc, positionType, direction, + impulse::SelectOp::create(builder, loc, positionType, direction, subTree.grad_right, tree.grad_right); - auto combinedWeight = enzyme::LogAddExpOp::create( + auto combinedWeight = impulse::LogAddExpOp::create( builder, loc, scalarType, tree.weight, subTree.weight); // Compute transition probability @@ -630,11 +631,11 @@ NUTSTreeState MCMC::combineTrees(OpBuilder &builder, Location loc, computeUniformTransitionProb(builder, loc, tree.weight, subTree.weight); } - auto randomOp = enzyme::RandomOp::create( + auto randomOp = impulse::RandomOp::create( builder, loc, TypeRange{rng.getType(), scalarType}, rng, zeroConst, oneConst, - enzyme::RngDistributionAttr::get(builder.getContext(), - enzyme::RngDistribution::UNIFORM)); + impulse::RngDistributionAttr::get(builder.getContext(), + impulse::RngDistribution::UNIFORM)); auto rngOut = randomOp.getOutputRngState(); auto uniformSample = randomOp.getResult(); @@ -642,14 +643,14 @@ NUTSTreeState MCMC::combineTrees(OpBuilder &builder, Location loc, builder, loc, arith::CmpFPredicate::OLT, uniformSample, transitionProb); auto qProposal = - enzyme::SelectOp::create(builder, loc, positionType, acceptNew, + impulse::SelectOp::create(builder, loc, positionType, acceptNew, subTree.q_proposal, tree.q_proposal); auto gradProposal = - enzyme::SelectOp::create(builder, loc, positionType, acceptNew, + impulse::SelectOp::create(builder, loc, positionType, acceptNew, subTree.grad_proposal, tree.grad_proposal); - auto UProposal = enzyme::SelectOp::create( + auto UProposal = 
impulse::SelectOp::create( builder, loc, scalarType, acceptNew, subTree.U_proposal, tree.U_proposal); - auto HProposal = enzyme::SelectOp::create( + auto HProposal = impulse::SelectOp::create( builder, loc, scalarType, acceptNew, subTree.H_proposal, tree.H_proposal); auto oneI64 = arith::ConstantOp::create( @@ -696,16 +697,16 @@ NUTSTreeState MCMC::combineTrees(OpBuilder &builder, Location loc, .rng = rngOut}; } -InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng, +InitialHMCState impulse::InitHMC(OpBuilder &builder, Location loc, Value rng, const HMCContext &ctx, Value initialPosition, bool debugDump) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); - auto initSplit = enzyme::RandomSplitOp::create( + auto initSplit = impulse::RandomSplitOp::create( builder, loc, TypeRange{rng.getType(), rng.getType()}, rng); - auto kernelSplit = enzyme::RandomSplitOp::create( + auto kernelSplit = impulse::RandomSplitOp::create( builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()}, initSplit.getResult(0)); auto rngForSampleKernel = kernelSplit.getResult(0); @@ -717,7 +718,7 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng, if (ctx.hasCustomLogpdf()) { q0 = initialPosition; auto flatType = RankedTensorType::get({ctx.positionSize}, elemType); - auto q0Flat = enzyme::ReshapeOp::create(builder, loc, flatType, q0); + auto q0Flat = impulse::ReshapeOp::create(builder, loc, flatType, q0); SmallVector callArgs; callArgs.push_back(q0Flat); callArgs.append(ctx.fnInputs.begin(), ctx.fnInputs.end()); @@ -749,7 +750,7 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng, generateResultTypesInit.append(ctx.fnResultTypes.begin(), ctx.fnResultTypes.end()); - auto generateOpInit = enzyme::GenerateOp::create( + auto generateOpInit = impulse::GenerateOp::create( builder, loc, generateResultTypesInit, ctx.fn, generateInputsInit, 
fullTraceInit, ctx.allAddresses, ctx.allAddresses, builder.getStringAttr("")); @@ -768,7 +769,7 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng, auto autodiffQ0Type = positionType; auto autodiffGradType = positionType; if (isCustomLogpdf) { - autodiffQ0 = enzyme::ReshapeOp::create(builder, loc, flatType, q0); + autodiffQ0 = impulse::ReshapeOp::create(builder, loc, flatType, q0); autodiffQ0Type = flatType; autodiffGradType = flatType; } @@ -832,7 +833,7 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng, generateResultTypesInner.append(ctx.fnResultTypes.begin(), ctx.fnResultTypes.end()); - auto generateOpInner = enzyme::GenerateOp::create( + auto generateOpInner = impulse::GenerateOp::create( builder, loc, generateResultTypesInner, ctx.fn, generateInputsInner, fullTraceInner, ctx.allAddresses, ctx.allAddresses, builder.getStringAttr("")); @@ -851,13 +852,13 @@ InitialHMCState MCMC::InitHMC(OpBuilder &builder, Location loc, Value rng, Value grad0 = autodiffInit.getResult(2); if (isCustomLogpdf) { - grad0 = enzyme::ReshapeOp::create(builder, loc, positionType, grad0); + grad0 = impulse::ReshapeOp::create(builder, loc, positionType, grad0); } return {q0, U0, grad0, rngForSampleKernel}; } -MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q, +MCMCKernelResult impulse::SampleHMC(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const HMCContext &ctx, bool debugDump) { auto positionType = ctx.getPositionType(); @@ -880,7 +881,7 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q, auto adjustedCtx = ctx.withStepSize(adjustedStepSize); // 1. 
Split RNG: [rngNext, rngMomentum, rngTransition] - auto sampleKernelSplit = enzyme::RandomSplitOp::create( + auto sampleKernelSplit = impulse::RandomSplitOp::create( builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()}, rng); auto rngNext = sampleKernelSplit.getResult(0); @@ -890,7 +891,7 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q, // 2. Sample fresh momentum p ~ N(0, M) Value rngForMomentum = rngMomentum; if (!ctx.hasCustomLogpdf()) { - auto momSplit = enzyme::RandomSplitOp::create( + auto momSplit = impulse::RandomSplitOp::create( builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum); rngForMomentum = momSplit.getResult(0); } @@ -920,7 +921,7 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q, SmallVector loopResultTypes = {positionType, positionType, positionType, scalarType, rngTransition.getType()}; auto forLoopOp = - enzyme::ForLoopOp::create(builder, loc, loopResultTypes, c0, numSteps, c1, + impulse::ForOp::create(builder, loc, loopResultTypes, c0, numSteps, c1, ValueRange{q, p0, grad, U, rngTransition}); Block *loopBody = builder.createBlock(&forLoopOp.getRegion()); @@ -942,7 +943,7 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q, adjustedCtx); // Yield [q, p, grad, U, rng] - enzyme::YieldOp::create( + impulse::YieldOp::create( builder, loc, ValueRange{step.q, step.p, step.grad, step.U, step.rng}); builder.setInsertionPointAfter(forLoopOp); @@ -971,11 +972,11 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q, auto accProb = arith::MinimumFOp::create(builder, loc, oneConst, expDH); // u ~ Uniform(0, 1) - auto randomOp = enzyme::RandomOp::create( + auto randomOp = impulse::RandomOp::create( builder, loc, TypeRange{rngAfterLeapfrog.getType(), scalarType}, rngAfterLeapfrog, zeroConst, oneConst, - enzyme::RngDistributionAttr::get(builder.getContext(), - enzyme::RngDistribution::UNIFORM)); + 
impulse::RngDistributionAttr::get(builder.getContext(), + impulse::RngDistribution::UNIFORM)); auto randUniform = randomOp.getResult(); // accepted = u < α @@ -983,17 +984,17 @@ MCMCKernelResult MCMC::SampleHMC(OpBuilder &builder, Location loc, Value q, builder, loc, arith::CmpFPredicate::OLT, randUniform, accProb); // 8. Select between original and proposal - auto qFinal = enzyme::SelectOp::create(builder, loc, positionType, + auto qFinal = impulse::SelectOp::create(builder, loc, positionType, acceptedTensor, qProposal, q); - auto gradFinal = enzyme::SelectOp::create(builder, loc, positionType, + auto gradFinal = impulse::SelectOp::create(builder, loc, positionType, acceptedTensor, gradProposal, grad); - auto UFinal = enzyme::SelectOp::create(builder, loc, scalarType, + auto UFinal = impulse::SelectOp::create(builder, loc, scalarType, acceptedTensor, UProposal, U); return {qFinal, gradFinal, UFinal, acceptedTensor, accProb, rngNext}; } -MCMCKernelResult MCMC::SampleNUTS(OpBuilder &builder, Location loc, Value q, +MCMCKernelResult impulse::SampleNUTS(OpBuilder &builder, Location loc, Value q, Value grad, Value U, Value rng, const NUTSContext &ctx, bool debugDump) { auto positionType = ctx.getPositionType(); @@ -1003,7 +1004,7 @@ MCMCKernelResult MCMC::SampleNUTS(OpBuilder &builder, Location loc, Value q, auto i1TensorType = RankedTensorType::get({}, builder.getI1Type()); // 1. Split RNG: [rngNext, rngMomentum, rngTree] - auto sampleKernelSplit = enzyme::RandomSplitOp::create( + auto sampleKernelSplit = impulse::RandomSplitOp::create( builder, loc, TypeRange{rng.getType(), rng.getType(), rng.getType()}, rng); auto rngNext = sampleKernelSplit.getResult(0); @@ -1013,7 +1014,7 @@ MCMCKernelResult MCMC::SampleNUTS(OpBuilder &builder, Location loc, Value q, // 2. 
Sample fresh momentum p ~ N(0, M) Value rngForMomentum = rngMomentum; if (!ctx.hasCustomLogpdf()) { - auto momSplit = enzyme::RandomSplitOp::create( + auto momSplit = impulse::RandomSplitOp::create( builder, loc, TypeRange{rng.getType(), rng.getType()}, rngMomentum); rngForMomentum = momSplit.getResult(0); } @@ -1083,7 +1084,7 @@ MCMCKernelResult MCMC::SampleNUTS(OpBuilder &builder, Location loc, Value q, meanAcceptProb, rngNext}; } -NUTSTreeState MCMC::buildBaseTree(OpBuilder &builder, Location loc, +NUTSTreeState impulse::buildBaseTree(OpBuilder &builder, Location loc, const IntegratorState &leaf, Value rng, Value direction, const NUTSContext &ctx) { auto positionType = ctx.getPositionType(); @@ -1161,21 +1162,21 @@ NUTSTreeState MCMC::buildBaseTree(OpBuilder &builder, Location loc, .rng = rngOut}; } -IntegratorState MCMC::getLeafFromTree(OpBuilder &builder, Location loc, +IntegratorState impulse::getLeafFromTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, const NUTSContext &ctx) { auto positionType = ctx.getPositionType(); - auto leafQ = enzyme::SelectOp::create(builder, loc, positionType, direction, + auto leafQ = impulse::SelectOp::create(builder, loc, positionType, direction, tree.q_right, tree.q_left); - auto leafP = enzyme::SelectOp::create(builder, loc, positionType, direction, + auto leafP = impulse::SelectOp::create(builder, loc, positionType, direction, tree.p_right, tree.p_left); - auto leafGrad = enzyme::SelectOp::create( + auto leafGrad = impulse::SelectOp::create( builder, loc, positionType, direction, tree.grad_right, tree.grad_left); return {leafQ, leafP, leafGrad}; } -SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, +SubtreeBuildResult impulse::buildIterativeSubtree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, Value direction, Value pCkpts, Value pSumCkpts, @@ -1210,7 +1211,7 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, 
whileInitVals.push_back(zeroI64); auto whileOp = - enzyme::WhileLoopOp::create(builder, loc, whileTypes, whileInitVals); + impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals); // Check: num_proposals < max_num_proposals && !turning && !diverging Block *condBlock = builder.createBlock(&whileOp.getConditionRegion()); @@ -1235,7 +1236,7 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, notDiverging); // Yield continue condition - enzyme::YieldOp::create(builder, loc, ValueRange{continueCond}); + impulse::YieldOp::create(builder, loc, ValueRange{continueCond}); Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion()); for (auto type : whileTypes) @@ -1254,7 +1255,7 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, IntegratorState leaf = getLeafFromTree(builder, loc, bodyTree, direction, ctx); - auto rngSplit2 = enzyme::RandomSplitOp::create( + auto rngSplit2 = impulse::RandomSplitOp::create( builder, loc, TypeRange{bodyTree.rng.getType(), bodyTree.rng.getType()}, bodyTree.rng); auto rngNext = rngSplit2.getResult(0); @@ -1269,11 +1270,11 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, builder, loc, arith::CmpIPredicate::eq, bodyTree.num_proposals, zeroI64); SmallVector treeTypes = newLeaf.getTypes(); - auto ifOp = enzyme::IfOp::create(builder, loc, treeTypes, isFirstLeaf); + auto ifOp = impulse::IfOp::create(builder, loc, treeTypes, isFirstLeaf); { Block *trueBranch = builder.createBlock(&ifOp.getTrueBranch()); builder.setInsertionPointToStart(trueBranch); - enzyme::YieldOp::create(builder, loc, newLeaf.toValues()); + impulse::YieldOp::create(builder, loc, newLeaf.toValues()); } { Block *falseBranch = builder.createBlock(&ifOp.getFalseBranch()); @@ -1281,7 +1282,7 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, NUTSTreeState combinedTree = combineTrees(builder, loc, bodyTree, newLeaf, direction, rngCombine, 
/*biased=*/false, ctx); - enzyme::YieldOp::create(builder, loc, combinedTree.toValues()); + impulse::YieldOp::create(builder, loc, combinedTree.toValues()); } builder.setInsertionPointAfter(ifOp); @@ -1300,7 +1301,7 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, updatedPSumCkpts, ckptIdxMin, ckptIdxMax, ctx, debugDump); updatedTree.turning = - enzyme::SelectOp::create(builder, loc, i1TensorType, isFirstLeaf, + impulse::SelectOp::create(builder, loc, i1TensorType, isFirstLeaf, newLeaf.turning, iterativeTurning); auto nextLeafIdx = arith::AddIOp::create(builder, loc, bodyLeafIdx, oneI64); @@ -1309,7 +1310,7 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, yieldVals.push_back(updatedPCkpts); yieldVals.push_back(updatedPSumCkpts); yieldVals.push_back(nextLeafIdx); - enzyme::YieldOp::create(builder, loc, yieldVals); + impulse::YieldOp::create(builder, loc, yieldVals); builder.setInsertionPointAfter(whileOp); @@ -1326,11 +1327,11 @@ SubtreeBuildResult MCMC::buildIterativeSubtree(OpBuilder &builder, Location loc, return {resultTree, resultPCkpts, resultPSumCkpts}; } -SubtreeBuildResult MCMC::doubleTree(OpBuilder &builder, Location loc, +SubtreeBuildResult impulse::doubleTree(OpBuilder &builder, Location loc, const NUTSTreeState &tree, Value direction, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump) { - auto rngSplit2 = enzyme::RandomSplitOp::create( + auto rngSplit2 = impulse::RandomSplitOp::create( builder, loc, TypeRange{tree.rng.getType(), tree.rng.getType()}, tree.rng); auto rngSubtree = rngSplit2.getResult(0); @@ -1349,7 +1350,7 @@ SubtreeBuildResult MCMC::doubleTree(OpBuilder &builder, Location loc, return {combinedTree, subtreeResult.pCkpts, subtreeResult.pSumCkpts}; } -NUTSTreeState MCMC::buildTree(OpBuilder &builder, Location loc, +NUTSTreeState impulse::buildTree(OpBuilder &builder, Location loc, const NUTSTreeState &initialTree, const NUTSContext &ctx, bool debugDump) 
{ auto elemType = @@ -1383,7 +1384,7 @@ NUTSTreeState MCMC::buildTree(OpBuilder &builder, Location loc, SmallVector whileInitVals = initialTree.toValues(); auto whileOp = - enzyme::WhileLoopOp::create(builder, loc, whileTypes, whileInitVals); + impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals); // Check: (depth < maxTreeDepth) && !turning && !diverging Block *condBlock = builder.createBlock(&whileOp.getConditionRegion()); @@ -1408,7 +1409,7 @@ NUTSTreeState MCMC::buildTree(OpBuilder &builder, Location loc, builder, loc, arith::AndIOp::create(builder, loc, depthCheck, notTurning), notDiverging); - enzyme::YieldOp::create(builder, loc, ValueRange{continueCond}); + impulse::YieldOp::create(builder, loc, ValueRange{continueCond}); Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion()); for (auto type : whileTypes) @@ -1427,7 +1428,7 @@ NUTSTreeState MCMC::buildTree(OpBuilder &builder, Location loc, Value bodyPCkpts = zeroCkpts; Value bodyPSumCkpts = zeroCkpts; - auto rngSplit3 = enzyme::RandomSplitOp::create( + auto rngSplit3 = impulse::RandomSplitOp::create( builder, loc, TypeRange{bodyTree.rng.getType(), bodyTree.rng.getType(), bodyTree.rng.getType()}, @@ -1436,11 +1437,11 @@ NUTSTreeState MCMC::buildTree(OpBuilder &builder, Location loc, auto rngDir = rngSplit3.getResult(1); auto rngDbl = rngSplit3.getResult(2); - auto directionRandom = enzyme::RandomOp::create( + auto directionRandom = impulse::RandomOp::create( builder, loc, TypeRange{rngDir.getType(), F64TensorType}, rngDir, zeroConst, oneConst, - enzyme::RngDistributionAttr::get(builder.getContext(), - enzyme::RngDistribution::UNIFORM)); + impulse::RngDistributionAttr::get(builder.getContext(), + impulse::RngDistribution::UNIFORM)); auto direction = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLT, directionRandom.getResult(), halfConst); @@ -1453,7 +1454,7 @@ NUTSTreeState MCMC::buildTree(OpBuilder &builder, Location loc, NUTSTreeState treeToYield = doubleResult.tree; 
treeToYield.rng = rngNext; - enzyme::YieldOp::create(builder, loc, treeToYield.toValues()); + impulse::YieldOp::create(builder, loc, treeToYield.toValues()); builder.setInsertionPointAfter(whileOp); @@ -1463,7 +1464,7 @@ NUTSTreeState MCMC::buildTree(OpBuilder &builder, Location loc, } std::pair -MCMC::leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx) { +impulse::leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx) { auto i64TensorType = cast(leafIdx.getType()); auto oneConst = arith::ConstantOp::create( @@ -1473,7 +1474,7 @@ MCMC::leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx) { // idx_max = popcount(leafIdx >> 1) auto shiftedIdx = arith::ShRUIOp::create(builder, loc, leafIdx, oneConst); auto idxMax = - enzyme::PopcountOp::create(builder, loc, i64TensorType, shiftedIdx); + impulse::PopcountOp::create(builder, loc, i64TensorType, shiftedIdx); // num_subtrees = popcount((~leafIdx & (leafIdx + 1)) - 1) auto leafIdxPlusOne = arith::AddIOp::create(builder, loc, leafIdx, oneConst); @@ -1487,7 +1488,7 @@ MCMC::leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx) { arith::AndIOp::create(builder, loc, notLeafIdx, leafIdxPlusOne); Value andMinusOne = arith::SubIOp::create(builder, loc, andResult, oneConst); Value numSubtrees = - enzyme::PopcountOp::create(builder, loc, i64TensorType, andMinusOne); + impulse::PopcountOp::create(builder, loc, i64TensorType, andMinusOne); // idx_min = idx_max - num_subtrees + 1 Value idxMaxMinusNumSubtrees = @@ -1498,7 +1499,7 @@ MCMC::leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx) { return {idxMin, idxMax}; } -Value MCMC::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, +Value impulse::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, Value pSum, Value pCkpts, Value pSumCkpts, Value idxMin, Value idxMax, const NUTSContext &ctx, bool debugDump) { @@ -1518,7 +1519,7 @@ Value 
MCMC::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, SmallVector whileInitVals = {idxMax, falseConst}; auto whileOp = - enzyme::WhileLoopOp::create(builder, loc, whileTypes, whileInitVals); + impulse::WhileOp::create(builder, loc, whileTypes, whileInitVals); Block *condBlock = builder.createBlock(&whileOp.getConditionRegion()); condBlock->addArgument(i64TensorType, loc); condBlock->addArgument(i1TensorType, loc); @@ -1535,7 +1536,7 @@ Value MCMC::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, auto notTurning = arith::XOrIOp::create(builder, loc, turningCond, trueConst); auto continueLoop = arith::AndIOp::create(builder, loc, iGeMin, notTurning); - enzyme::YieldOp::create(builder, loc, ValueRange{continueLoop.getResult()}); + impulse::YieldOp::create(builder, loc, ValueRange{continueLoop.getResult()}); Block *bodyBlock = builder.createBlock(&whileOp.getBodyRegion()); bodyBlock->addArgument(i64TensorType, loc); @@ -1546,10 +1547,10 @@ Value MCMC::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, auto zeroI64 = arith::ConstantOp::create( builder, loc, i64TensorType, DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(0))); - Value pLeft = enzyme::DynamicSliceOp::create( + Value pLeft = impulse::DynamicSliceOp::create( builder, loc, positionType, pCkpts, ValueRange{iBody, zeroI64}, builder.getDenseI64ArrayAttr({1, ctx.positionSize})); - Value pSumCkptI = enzyme::DynamicSliceOp::create( + Value pSumCkptI = impulse::DynamicSliceOp::create( builder, loc, positionType, pSumCkpts, ValueRange{iBody, zeroI64}, builder.getDenseI64ArrayAttr({1, ctx.positionSize})); @@ -1561,14 +1562,14 @@ Value MCMC::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, Value turningAtCkpt = checkTurning(builder, loc, pLeft, p, subtreePSum, ctx); Value iNext = arith::SubIOp::create(builder, loc, iBody, oneI64); - enzyme::YieldOp::create(builder, loc, ValueRange{iNext, turningAtCkpt}); + impulse::YieldOp::create(builder, 
loc, ValueRange{iNext, turningAtCkpt}); builder.setInsertionPointAfter(whileOp); return whileOp.getResult(1); } std::pair -MCMC::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, +impulse::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, Value ckptIdxMax, Value p, Value pSum, Value pCkpts, Value pSumCkpts, const NUTSContext &ctx, bool debugDump) { @@ -1588,25 +1589,25 @@ MCMC::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, // Compute updates only on even leafIdx SmallVector ifResultTypes = {pCkptsType, pCkptsType}; - auto ifOp = enzyme::IfOp::create(builder, loc, ifResultTypes, isEven); + auto ifOp = impulse::IfOp::create(builder, loc, ifResultTypes, isEven); { Block *trueBranch = builder.createBlock(&ifOp.getTrueBranch()); builder.setInsertionPointToStart(trueBranch); - auto updatedPCkpts = enzyme::DynamicUpdateSliceOp::create( + auto updatedPCkpts = impulse::DynamicUpdateSliceOp::create( builder, loc, pCkptsType, pCkpts, p, ValueRange{ckptIdxMax, zeroI64}); - auto updatedPSumCkpts = enzyme::DynamicUpdateSliceOp::create( + auto updatedPSumCkpts = impulse::DynamicUpdateSliceOp::create( builder, loc, pCkptsType, pSumCkpts, pSum, ValueRange{ckptIdxMax, zeroI64}); - enzyme::YieldOp::create(builder, loc, + impulse::YieldOp::create(builder, loc, ValueRange{updatedPCkpts, updatedPSumCkpts}); } { Block *falseBranch = builder.createBlock(&ifOp.getFalseBranch()); builder.setInsertionPointToStart(falseBranch); - enzyme::YieldOp::create(builder, loc, ValueRange{pCkpts, pSumCkpts}); + impulse::YieldOp::create(builder, loc, ValueRange{pCkpts, pSumCkpts}); } builder.setInsertionPointAfter(ifOp); @@ -1617,7 +1618,7 @@ MCMC::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, return {finalPCkpts, finalPSumCkpts}; } -DualAveragingState MCMC::initDualAveraging(OpBuilder &builder, Location loc, +DualAveragingState impulse::initDualAveraging(OpBuilder &builder, Location loc, Value stepSize) { auto stepSizeType = 
cast(stepSize.getType()); auto elemType = stepSizeType.getElementType(); @@ -1650,7 +1651,7 @@ DualAveragingState MCMC::initDualAveraging(OpBuilder &builder, Location loc, } DualAveragingState -MCMC::updateDualAveraging(OpBuilder &builder, Location loc, +impulse::updateDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, Value acceptProb, const DualAveragingConfig &config) { // Dual Averaging update: @@ -1742,14 +1743,14 @@ MCMC::updateDualAveraging(OpBuilder &builder, Location loc, }; } -Value MCMC::getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, +Value impulse::getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, const DualAveragingState &state, bool final) { Value logStepSize = final ? state.log_step_size_avg : state.log_step_size; return math::ExpOp::create(builder, loc, logStepSize); } -WelfordState MCMC::initWelford(OpBuilder &builder, Location loc, +WelfordState impulse::initWelford(OpBuilder &builder, Location loc, int64_t positionSize, bool diagonal) { auto elemType = builder.getF64Type(); auto i64TensorType = RankedTensorType::get({}, builder.getI64Type()); @@ -1780,7 +1781,7 @@ WelfordState MCMC::initWelford(OpBuilder &builder, Location loc, return {mean, m2, n}; } -WelfordState MCMC::updateWelford(OpBuilder &builder, Location loc, +WelfordState impulse::updateWelford(OpBuilder &builder, Location loc, const WelfordState &state, Value sample, const WelfordConfig &config) { // Algorithm: @@ -1803,7 +1804,7 @@ WelfordState MCMC::updateWelford(OpBuilder &builder, Location loc, auto scalarType = RankedTensorType::get({}, elemType); Value nFloat = arith::SIToFPOp::create(builder, loc, scalarType, nNew); - Value nBroadcast = enzyme::BroadcastOp::create(builder, loc, sampleType, + Value nBroadcast = BroadcastOp::create(builder, loc, sampleType, nFloat, sampleType.getShape()); Value deltaPre = arith::SubFOp::create(builder, loc, sample, state.mean); @@ -1821,7 +1822,7 @@ WelfordState 
MCMC::updateWelford(OpBuilder &builder, Location loc, m2New = arith::AddFOp::create(builder, loc, state.m2, product); } else { // Dense auto m2Type = cast(state.m2.getType()); - Value outerProduct = enzyme::DotOp::create( + Value outerProduct = impulse::DotOp::create( builder, loc, m2Type, deltaPost, deltaPre, /*lhs_batching_dimensions=*/builder.getDenseI64ArrayAttr({}), /*rhs_batching_dimensions=*/builder.getDenseI64ArrayAttr({}), @@ -1833,7 +1834,7 @@ WelfordState MCMC::updateWelford(OpBuilder &builder, Location loc, return {meanNew, m2New, nNew}; } -Value MCMC::finalizeWelford(OpBuilder &builder, Location loc, +Value impulse::finalizeWelford(OpBuilder &builder, Location loc, const WelfordState &state, const WelfordConfig &config) { // Compute sample covariance: cov = m2 / (n - 1) @@ -1849,7 +1850,7 @@ Value MCMC::finalizeWelford(OpBuilder &builder, Location loc, Value nMinus1Float = arith::SIToFPOp::create(builder, loc, scalarType, nMinus1); - Value nMinus1Bcast = enzyme::BroadcastOp::create( + Value nMinus1Bcast = BroadcastOp::create( builder, loc, m2Type, nMinus1Float, m2Type.getShape()); Value cov = arith::DivFOp::create(builder, loc, state.m2, nMinus1Bcast); @@ -1869,7 +1870,7 @@ Value MCMC::finalizeWelford(OpBuilder &builder, Location loc, Value nPlusFive = arith::AddFOp::create(builder, loc, nFloat, fiveConst); Value scale = arith::DivFOp::create(builder, loc, nFloat, nPlusFive); - Value scaleBcast = enzyme::BroadcastOp::create(builder, loc, m2Type, scale, + Value scaleBcast = BroadcastOp::create(builder, loc, m2Type, scale, m2Type.getShape()); Value scaledCov = arith::MulFOp::create(builder, loc, scaleBcast, cov); @@ -1881,12 +1882,12 @@ Value MCMC::finalizeWelford(OpBuilder &builder, Location loc, arith::DivFOp::create(builder, loc, shrinkageBaseConst, nPlusFive); if (config.diagonal) { - Value shrinkageBcast = enzyme::BroadcastOp::create( + Value shrinkageBcast = BroadcastOp::create( builder, loc, m2Type, shrinkage, m2Type.getShape()); cov = 
arith::AddFOp::create(builder, loc, scaledCov, shrinkageBcast); } else { Value identity = createIdentityMatrix(builder, loc, m2Type); - Value shrinkageBcast = enzyme::BroadcastOp::create( + Value shrinkageBcast = BroadcastOp::create( builder, loc, m2Type, shrinkage, m2Type.getShape()); Value shrinkageI = arith::MulFOp::create(builder, loc, shrinkageBcast, identity); @@ -1897,7 +1898,7 @@ Value MCMC::finalizeWelford(OpBuilder &builder, Location loc, return cov; } -SmallVector MCMC::buildAdaptationSchedule(int64_t numSteps) { +SmallVector impulse::buildAdaptationSchedule(int64_t numSteps) { // |<-- start buffer -->|<-- middle windows (doubling) -->|<-- end buffer -->| // | (no mass) | (collect + adapt mass) | (no mass) | // | step size only | step size + mass matrix | step size only | @@ -1948,12 +1949,12 @@ SmallVector MCMC::buildAdaptationSchedule(int64_t numSteps) { return schedule; } -Value MCMC::unconstrainPosition(OpBuilder &builder, Location loc, +Value impulse::unconstrainPosition(OpBuilder &builder, Location loc, Value constrained, ArrayRef supports) { bool hasConstraints = false; for (const auto &info : supports) { - if (info.support && info.support.getKind() != enzyme::SupportKind::REAL) { + if (info.support && info.support.getKind() != impulse::SupportKind::REAL) { hasConstraints = true; break; } @@ -1967,18 +1968,18 @@ Value MCMC::unconstrainPosition(OpBuilder &builder, Location loc, auto positionType1D = RankedTensorType::get({positionSize}, elemType); Value constrained1D = - enzyme::ReshapeOp::create(builder, loc, positionType1D, constrained); + impulse::ReshapeOp::create(builder, loc, positionType1D, constrained); SmallVector slices; for (const auto &info : supports) { auto sliceType = RankedTensorType::get({info.size}, elemType); - auto slice = enzyme::SliceOp::create( + auto slice = impulse::SliceOp::create( builder, loc, sliceType, constrained1D, builder.getDenseI64ArrayAttr({info.offset}), builder.getDenseI64ArrayAttr({info.offset + info.size}), 
builder.getDenseI64ArrayAttr({1})); Value unconstrainedSlice; - if (info.support && info.support.getKind() != enzyme::SupportKind::REAL) { + if (info.support && info.support.getKind() != impulse::SupportKind::REAL) { unconstrainedSlice = transforms::unconstrain(builder, loc, slice, info.support); } else { @@ -2016,15 +2017,15 @@ Value MCMC::unconstrainPosition(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64ScalarType, builder.getI64IntegerAttr(offset + j))); - auto elemSliced = enzyme::DynamicSliceOp::create( + auto elemSliced = impulse::DynamicSliceOp::create( builder, loc, elemType1DSlice, slices[i], ValueRange{elemIdx}, builder.getDenseI64ArrayAttr({1})); auto elem = - enzyme::ReshapeOp::create(builder, loc, elemType0D, elemSliced); + impulse::ReshapeOp::create(builder, loc, elemType0D, elemSliced); - auto resultSliced = enzyme::ReshapeOp::create( + auto resultSliced = impulse::ReshapeOp::create( builder, loc, RankedTensorType::get({1}, elemType), elem); - result1D = enzyme::DynamicUpdateSliceOp::create( + result1D = impulse::DynamicUpdateSliceOp::create( builder, loc, resultType1D, result1D, resultSliced, ValueRange{resultIdx}); } @@ -2034,15 +2035,15 @@ Value MCMC::unconstrainPosition(OpBuilder &builder, Location loc, } auto resultType2D = RankedTensorType::get({1, positionSize}, elemType); - return enzyme::ReshapeOp::create(builder, loc, resultType2D, result1D); + return impulse::ReshapeOp::create(builder, loc, resultType2D, result1D); } -Value MCMC::constrainPosition(OpBuilder &builder, Location loc, +Value impulse::constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef supports) { bool hasConstraints = false; for (const auto &info : supports) { - if (info.support && info.support.getKind() != enzyme::SupportKind::REAL) { + if (info.support && info.support.getKind() != impulse::SupportKind::REAL) { hasConstraints = true; break; } @@ -2056,19 +2057,19 @@ Value MCMC::constrainPosition(OpBuilder &builder, Location loc, 
auto positionType1D = RankedTensorType::get({positionSize}, elemType); Value unconstrained1D = - enzyme::ReshapeOp::create(builder, loc, positionType1D, unconstrained); + impulse::ReshapeOp::create(builder, loc, positionType1D, unconstrained); SmallVector slices; for (const auto &info : supports) { auto sliceType = RankedTensorType::get({info.size}, elemType); - auto slice = enzyme::SliceOp::create( + auto slice = impulse::SliceOp::create( builder, loc, sliceType, unconstrained1D, builder.getDenseI64ArrayAttr({info.offset}), builder.getDenseI64ArrayAttr({info.offset + info.size}), builder.getDenseI64ArrayAttr({1})); Value constrainedSlice; - if (info.support && info.support.getKind() != enzyme::SupportKind::REAL) { + if (info.support && info.support.getKind() != impulse::SupportKind::REAL) { constrainedSlice = transforms::constrain(builder, loc, slice, info.support); } else { @@ -2106,15 +2107,15 @@ Value MCMC::constrainPosition(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64ScalarType, builder.getI64IntegerAttr(offset + j))); - auto elemSliced = enzyme::DynamicSliceOp::create( + auto elemSliced = impulse::DynamicSliceOp::create( builder, loc, elemType1DSlice, slices[i], ValueRange{elemIdx}, builder.getDenseI64ArrayAttr({1})); auto elem = - enzyme::ReshapeOp::create(builder, loc, elemType0D, elemSliced); + impulse::ReshapeOp::create(builder, loc, elemType0D, elemSliced); auto resultSliced = - enzyme::ReshapeOp::create(builder, loc, elemType1DSlice, elem); - result1D = enzyme::DynamicUpdateSliceOp::create( + impulse::ReshapeOp::create(builder, loc, elemType1DSlice, elem); + result1D = impulse::DynamicUpdateSliceOp::create( builder, loc, resultType1D, result1D, resultSliced, ValueRange{resultIdx}); } @@ -2124,10 +2125,10 @@ Value MCMC::constrainPosition(OpBuilder &builder, Location loc, } auto resultType2D = RankedTensorType::get({1, positionSize}, elemType); - return enzyme::ReshapeOp::create(builder, loc, resultType2D, result1D); + return 
impulse::ReshapeOp::create(builder, loc, resultType2D, result1D); } -Value MCMC::computeTotalJacobianCorrection(OpBuilder &builder, Location loc, +Value impulse::computeTotalJacobianCorrection(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef supports) { auto inputType = cast(unconstrained.getType()); @@ -2144,14 +2145,14 @@ Value MCMC::computeTotalJacobianCorrection(OpBuilder &builder, Location loc, int64_t positionSize = inputType.getShape()[1]; auto positionType1D = RankedTensorType::get({positionSize}, elemType); Value unconstrained1D = - enzyme::ReshapeOp::create(builder, loc, positionType1D, unconstrained); + impulse::ReshapeOp::create(builder, loc, positionType1D, unconstrained); for (const auto &info : supports) { - if (!info.support || info.support.getKind() == enzyme::SupportKind::REAL) + if (!info.support || info.support.getKind() == impulse::SupportKind::REAL) continue; auto sliceType = RankedTensorType::get({info.size}, elemType); - auto slice = enzyme::SliceOp::create( + auto slice = impulse::SliceOp::create( builder, loc, sliceType, unconstrained1D, builder.getDenseI64ArrayAttr({info.offset}), builder.getDenseI64ArrayAttr({info.offset + info.size}), diff --git a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h index 5e0119dcfa55..d79d0a9118cc 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h @@ -11,6 +11,7 @@ #ifndef ENZYME_MLIR_INTERFACES_HMC_UTILS_H #define ENZYME_MLIR_INTERFACES_HMC_UTILS_H +#include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" #include "Interfaces/TransformUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -22,17 +23,16 @@ #include "mlir/IR/Value.h" namespace mlir { -namespace enzyme { -namespace MCMC { +namespace impulse { struct SupportInfo { int64_t offset; int64_t traceOffset; int64_t size; - enzyme::SupportAttr support; + impulse::SupportAttr support; SupportInfo(int64_t offset, int64_t traceOffset, int64_t size, - 
enzyme::SupportAttr support) + impulse::SupportAttr support) : offset(offset), traceOffset(traceOffset), size(size), support(support) { } }; @@ -156,7 +156,7 @@ struct HMCContext { bool hasConstrainedSupports() const { for (const auto &info : supports) { - if (info.support && info.support.getKind() != enzyme::SupportKind::REAL) + if (info.support && info.support.getKind() != impulse::SupportKind::REAL) return true; } return false; @@ -227,7 +227,7 @@ struct SubtreeBuildResult { }; /// Conditionally dump a value for debugging. -/// Emits an enzyme::DumpOp if `debugDump` is true; otherwise has no effect. +/// Emits an impulse::DumpOp if `debugDump` is true; otherwise has no effect. Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump); @@ -460,8 +460,7 @@ Value constrainPosition(OpBuilder &builder, Location loc, Value unconstrained, Value computeTotalJacobianCorrection(OpBuilder &builder, Location loc, Value unconstrained, ArrayRef supports); -} // namespace MCMC -} // namespace enzyme +} // namespace impulse } // namespace mlir #endif // ENZYME_MLIR_INTERFACES_HMC_UTILS_H diff --git a/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.cpp index 7f77d039dbfc..41be8e2dcfb2 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.cpp @@ -9,6 +9,7 @@ #include "TransformUtils.h" +#include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -42,53 +43,53 @@ Value transforms::createLogSigmoid(OpBuilder &builder, Location loc, Value x) { builder, loc, xType, DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, 0.0))); auto softplusNegX = - enzyme::LogAddExpOp::create(builder, loc, xType, negX, zeroConst); + impulse::LogAddExpOp::create(builder, loc, xType, negX, zeroConst); return arith::NegFOp::create(builder, loc, softplusNegX); } 
int64_t transforms::getUnconstrainedSize(int64_t constrainedSize, - SupportKind kind) { + impulse::SupportKind kind) { switch (kind) { - case SupportKind::REAL: - case SupportKind::POSITIVE: - case SupportKind::UNIT_INTERVAL: - case SupportKind::INTERVAL: - case SupportKind::GREATER_THAN: - case SupportKind::LESS_THAN: + case impulse::SupportKind::REAL: + case impulse::SupportKind::POSITIVE: + case impulse::SupportKind::UNIT_INTERVAL: + case impulse::SupportKind::INTERVAL: + case impulse::SupportKind::GREATER_THAN: + case impulse::SupportKind::LESS_THAN: return constrainedSize; } llvm_unreachable("Unknown SupportKind"); } int64_t transforms::getConstrainedSize(int64_t unconstrainedSize, - SupportKind kind) { + impulse::SupportKind kind) { switch (kind) { - case SupportKind::REAL: - case SupportKind::POSITIVE: - case SupportKind::UNIT_INTERVAL: - case SupportKind::INTERVAL: - case SupportKind::GREATER_THAN: - case SupportKind::LESS_THAN: + case impulse::SupportKind::REAL: + case impulse::SupportKind::POSITIVE: + case impulse::SupportKind::UNIT_INTERVAL: + case impulse::SupportKind::INTERVAL: + case impulse::SupportKind::GREATER_THAN: + case impulse::SupportKind::LESS_THAN: return unconstrainedSize; } llvm_unreachable("Unknown SupportKind"); } Value transforms::unconstrain(OpBuilder &builder, Location loc, - Value constrained, SupportAttr support) { + Value constrained, impulse::SupportAttr support) { auto kind = support.getKind(); auto xType = cast(constrained.getType()); auto elemType = xType.getElementType(); switch (kind) { - case SupportKind::REAL: + case impulse::SupportKind::REAL: // Identity return constrained; - case SupportKind::POSITIVE: + case impulse::SupportKind::POSITIVE: return math::LogOp::create(builder, loc, constrained); - case SupportKind::UNIT_INTERVAL: + case impulse::SupportKind::UNIT_INTERVAL: return createLogit(builder, loc, constrained); - case SupportKind::INTERVAL: { + case impulse::SupportKind::INTERVAL: { // z = logit((x - a) / (b - a)) 
auto lowerAttr = support.getLowerBound(); auto upperAttr = support.getUpperBound(); @@ -110,7 +111,7 @@ Value transforms::unconstrain(OpBuilder &builder, Location loc, auto normalized = arith::DivFOp::create(builder, loc, shifted, scaleConst); return createLogit(builder, loc, normalized); } - case SupportKind::GREATER_THAN: { + case impulse::SupportKind::GREATER_THAN: { // z = log(x - lower) auto lowerAttr = support.getLowerBound(); if (!lowerAttr) { @@ -124,7 +125,7 @@ Value transforms::unconstrain(OpBuilder &builder, Location loc, auto shifted = arith::SubFOp::create(builder, loc, constrained, lowerConst); return math::LogOp::create(builder, loc, shifted); } - case SupportKind::LESS_THAN: { + case impulse::SupportKind::LESS_THAN: { // z = log(upper - x) auto upperAttr = support.getUpperBound(); if (!upperAttr) { @@ -143,22 +144,22 @@ Value transforms::unconstrain(OpBuilder &builder, Location loc, } Value transforms::constrain(OpBuilder &builder, Location loc, - Value unconstrained, SupportAttr support) { + Value unconstrained, impulse::SupportAttr support) { auto kind = support.getKind(); auto zType = cast(unconstrained.getType()); auto elemType = zType.getElementType(); switch (kind) { - case SupportKind::REAL: + case impulse::SupportKind::REAL: // Identity return unconstrained; - case SupportKind::POSITIVE: + case impulse::SupportKind::POSITIVE: return math::ExpOp::create(builder, loc, unconstrained); - case SupportKind::UNIT_INTERVAL: + case impulse::SupportKind::UNIT_INTERVAL: // x = sigmoid(z) - return enzyme::LogisticOp::create(builder, loc, unconstrained.getType(), - unconstrained); - case SupportKind::INTERVAL: { + return impulse::LogisticOp::create(builder, loc, unconstrained.getType(), + unconstrained); + case impulse::SupportKind::INTERVAL: { // x = a + (b - a) * sigmoid(z) auto lowerAttr = support.getLowerBound(); auto upperAttr = support.getUpperBound(); @@ -176,12 +177,12 @@ Value transforms::constrain(OpBuilder &builder, Location loc, 
DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, upper - lower))); - auto sigmoid = enzyme::LogisticOp::create( + auto sigmoid = impulse::LogisticOp::create( builder, loc, unconstrained.getType(), unconstrained); auto scaled = arith::MulFOp::create(builder, loc, scaleConst, sigmoid); return arith::AddFOp::create(builder, loc, lowerConst, scaled); } - case SupportKind::GREATER_THAN: { + case impulse::SupportKind::GREATER_THAN: { // x = lower + exp(z) auto lowerAttr = support.getLowerBound(); if (!lowerAttr) { @@ -195,7 +196,7 @@ Value transforms::constrain(OpBuilder &builder, Location loc, auto expZ = math::ExpOp::create(builder, loc, unconstrained); return arith::AddFOp::create(builder, loc, lowerConst, expZ); } - case SupportKind::LESS_THAN: { + case impulse::SupportKind::LESS_THAN: { // x = upper - exp(z) auto upperAttr = support.getUpperBound(); if (!upperAttr) { @@ -215,32 +216,33 @@ Value transforms::constrain(OpBuilder &builder, Location loc, } Value transforms::logAbsDetJacobian(OpBuilder &builder, Location loc, - Value unconstrained, SupportAttr support) { + Value unconstrained, + impulse::SupportAttr support) { auto kind = support.getKind(); auto zType = cast(unconstrained.getType()); auto elemType = zType.getElementType(); auto scalarType = RankedTensorType::get({}, elemType); switch (kind) { - case SupportKind::REAL: { + case impulse::SupportKind::REAL: { // Identity: log|det(I)| = 0 return arith::ConstantOp::create( builder, loc, scalarType, DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 0.0))); } - case SupportKind::POSITIVE: { + case impulse::SupportKind::POSITIVE: { // x = exp(z), dx/dz = exp(z) // log|det(J)| = sum(log|dx_i/dz_i|) = sum(z) auto ones = arith::ConstantOp::create( builder, loc, zType, DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0))); - return enzyme::DotOp::create( + return impulse::DotOp::create( builder, loc, scalarType, unconstrained, ones, builder.getDenseI64ArrayAttr({}), 
builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0})); } - case SupportKind::UNIT_INTERVAL: { + case impulse::SupportKind::UNIT_INTERVAL: { // x = sigmoid(z), dx/dz = sigmoid(z) * (1 - sigmoid(z)) // log|det(J)| = sum(log(sigmoid(z)) + log(1 - sigmoid(z))) // = sum(log_sigmoid(z) + log_sigmoid(-z)) @@ -251,12 +253,12 @@ Value transforms::logAbsDetJacobian(OpBuilder &builder, Location loc, auto ones = arith::ConstantOp::create( builder, loc, zType, DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0))); - return enzyme::DotOp::create( + return impulse::DotOp::create( builder, loc, scalarType, logProduct, ones, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0})); } - case SupportKind::INTERVAL: { + case impulse::SupportKind::INTERVAL: { // log|det(J)| = sum(log_sigmoid(z) + log(1 - sigmoid(z))) + n*log(scale) auto lowerAttr = support.getLowerBound(); auto upperAttr = support.getUpperBound(); @@ -272,7 +274,7 @@ Value transforms::logAbsDetJacobian(OpBuilder &builder, Location loc, auto ones = arith::ConstantOp::create( builder, loc, zType, DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0))); - auto sumLogProduct = enzyme::DotOp::create( + auto sumLogProduct = impulse::DotOp::create( builder, loc, scalarType, logProduct, ones, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0})); @@ -285,13 +287,13 @@ Value transforms::logAbsDetJacobian(OpBuilder &builder, Location loc, builder.getFloatAttr(elemType, logScaleTerm))); return arith::AddFOp::create(builder, loc, sumLogProduct, logScaleConst); } - case SupportKind::GREATER_THAN: - case SupportKind::LESS_THAN: { + case impulse::SupportKind::GREATER_THAN: + case impulse::SupportKind::LESS_THAN: { // log|det(J)| = sum(z) auto ones = arith::ConstantOp::create( builder, loc, zType, 
DenseElementsAttr::get(zType, builder.getFloatAttr(elemType, 1.0))); - return enzyme::DotOp::create( + return impulse::DotOp::create( builder, loc, scalarType, unconstrained, ones, builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({}), builder.getDenseI64ArrayAttr({0}), builder.getDenseI64ArrayAttr({0})); diff --git a/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h b/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h index 4b42c74f417f..e07257126139 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h @@ -11,6 +11,7 @@ #ifndef ENZYME_MLIR_INTERFACES_TRANSFORM_UTILS_H #define ENZYME_MLIR_INTERFACES_TRANSFORM_UTILS_H +#include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -23,22 +24,22 @@ namespace enzyme { namespace transforms { /// Get the unconstrained size given a constrained size and support kind. -int64_t getUnconstrainedSize(int64_t constrainedSize, SupportKind kind); +int64_t getUnconstrainedSize(int64_t constrainedSize, impulse::SupportKind kind); /// Get the constrained size given an unconstrained size and support kind. -int64_t getConstrainedSize(int64_t unconstrainedSize, SupportKind kind); +int64_t getConstrainedSize(int64_t unconstrainedSize, impulse::SupportKind kind); /// Transform from constrained to unconstrained space. Value unconstrain(OpBuilder &builder, Location loc, Value constrained, - SupportAttr support); + impulse::SupportAttr support); /// Transform from unconstrained to constrained space. Value constrain(OpBuilder &builder, Location loc, Value unconstrained, - SupportAttr support); + impulse::SupportAttr support); /// Compute log |det J| of the transform from unconstrained to constrained. 
Value logAbsDetJacobian(OpBuilder &builder, Location loc, Value unconstrained, - SupportAttr support); + impulse::SupportAttr support); Value createLogit(OpBuilder &builder, Location loc, Value x); Value createLogSigmoid(OpBuilder &builder, Location loc, Value x); diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 53ea74f11a9c..b0559588935e 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms DEPENDS MLIREnzymePassIncGen MLIREnzymeEnumsIncGen + MLIRImpulseOpsIncGen LINK_LIBS PUBLIC MLIRAffineDialect @@ -55,4 +56,5 @@ add_mlir_dialect_library(MLIREnzymeTransforms MLIRTransformUtils MLIREnzymeAutoDiffInterface + MLIRImpulse ) diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index d4b9d8052b2a..862cb7bf3ac9 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -44,6 +44,7 @@ def ProbProgPass : Pass<"probprog"> { "arith::ArithDialect", "complex::ComplexDialect", "enzyme::EnzymeDialect", + "impulse::ImpulseDialect", ]; let options = [ Option< diff --git a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp index bcdae1786b98..2a5e7e4a87ff 100644 --- a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp @@ -10,6 +10,7 @@ // This file implements a pass to handle probabilistic programming operations //===----------------------------------------------------------------------===// +#include "Dialect/Impulse/Impulse.h" #include "Dialect/Ops.h" #include "Interfaces/HMCUtils.h" #include "Interfaces/ProbProgUtils.h" @@ -30,8 +31,7 @@ using namespace mlir; using namespace mlir::enzyme; -using namespace enzyme; -using namespace enzyme::MCMC; +using namespace mlir::impulse; namespace mlir { namespace enzyme { @@ -52,25 +52,25 @@ static int64_t 
computeTensorElementCount(RankedTensorType tensorType) { return elemCount; } -using SampleOpMap = DenseMap; +using SampleOpMap = DenseMap; static SampleOpMap buildSampleOpMap(FunctionOpInterface fn) { SampleOpMap map; - fn.walk([&](enzyme::SampleOp sampleOp) { + fn.walk([&](impulse::SampleOp sampleOp) { if (auto symbol = sampleOp.getSymbolAttr()) map[symbol] = sampleOp; }); return map; } -static enzyme::SampleOp findSampleBySymbol(const SampleOpMap &map, - Attribute targetSymbol) { +static impulse::SampleOp findSampleBySymbol(const SampleOpMap &map, + Attribute targetSymbol) { auto it = map.find(targetSymbol); return it != map.end() ? it->second : nullptr; } static int64_t computeSampleElementCount(Operation *op, - enzyme::SampleOp sampleOp) { + impulse::SampleOp sampleOp) { int64_t totalCount = 0; for (unsigned i = 1; i < sampleOp.getNumResults(); ++i) { auto resultType = sampleOp.getResult(i).getType(); @@ -180,12 +180,12 @@ computeOffsetForSampleInSelection(Operation *op, FunctionOpInterface fn, return -1; } -static SmallVector +static SmallVector collectSupportInfoForSelection(Operation *op, FunctionOpInterface fn, ArrayAttr selection, ArrayAttr allAddresses, SymbolTableCollection &symbolTable) { auto sampleMap = buildSampleOpMap(fn); - SmallVector supports; + SmallVector supports; int64_t currentPositionOffset = 0; for (auto addr : selection) { @@ -279,16 +279,17 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { pm.getDependentDialects(registry); } - registry.insert(); + registry + .insert(); } struct LowerUntracedCallPattern - : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(enzyme::UntracedCallOp CI, + LogicalResult matchAndRewrite(impulse::UntracedCallOp CI, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; @@ -304,7 +305,7 @@ struct ProbProgPass : public 
enzyme::impl::ProbProgPassBase { FunctionOpInterface NewF = putils->newFunc; SmallVector toErase; - NewF.walk([&](enzyme::SampleOp sampleOp) { + NewF.walk([&](impulse::SampleOp sampleOp) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sampleOp); @@ -336,10 +337,10 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { }; struct LowerSimulatePattern - : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(enzyme::SimulateOp CI, + LogicalResult matchAndRewrite(impulse::SimulateOp CI, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; @@ -385,7 +386,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { int64_t currentOffset = 0; SmallVector toErase; - auto result = NewF.walk([&](enzyme::SampleOp sampleOp) -> WalkResult { + auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sampleOp); @@ -455,7 +456,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto flatSampleType = RankedTensorType::get( {1, numElements}, sampleType.getElementType()); - auto flatSample = enzyme::ReshapeOp::create( + auto flatSample = impulse::ReshapeOp::create( rewriter, sampleOp.getLoc(), flatSampleType, sampleValue); auto i64S = RankedTensorType::get({}, rewriter.getI64Type()); auto row0 = arith::ConstantOp::create( @@ -465,7 +466,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, sampleOp.getLoc(), i64S, DenseElementsAttr::get( i64S, rewriter.getI64IntegerAttr(currentOffset))); - currTrace = enzyme::DynamicUpdateSliceOp::create( + currTrace = impulse::DynamicUpdateSliceOp::create( rewriter, sampleOp.getLoc(), traceType, currTrace, flatSample, ValueRange{row0, colOff}) .getResult(); @@ -514,7 +515,7 @@ struct ProbProgPass : public 
enzyme::impl::ProbProgPassBase { for (auto t : genFn.getResultTypes()) simResultTypes.push_back(t); - auto nestedSimulate = enzyme::SimulateOp::create( + auto nestedSimulate = impulse::SimulateOp::create( rewriter, sampleOp.getLoc(), simResultTypes, sampleOp.getFnAttr(), sampleOp.getInputs(), subSelection); auto subTrace = nestedSimulate.getTrace(); @@ -538,7 +539,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, sampleOp.getLoc(), i64S, DenseElementsAttr::get( i64S, rewriter.getI64IntegerAttr(mergeOffset))); - currTrace = enzyme::DynamicUpdateSliceOp::create( + currTrace = impulse::DynamicUpdateSliceOp::create( rewriter, sampleOp.getLoc(), traceType, currTrace, subTrace, ValueRange{row0, colOff}) .getResult(); @@ -591,14 +592,14 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { } }; - struct LowerMCMCPattern : public mlir::OpRewritePattern { + struct LowerMCMCPattern : public mlir::OpRewritePattern { bool debugDump; LowerMCMCPattern(MLIRContext *context, bool debugDump, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), debugDump(debugDump) {} - LogicalResult matchAndRewrite(enzyme::MCMCOp mcmcOp, + LogicalResult matchAndRewrite(impulse::InferOp mcmcOp, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; @@ -930,8 +931,8 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { warmupInitArgs.push_back(windowIdx); auto warmupLoop = - enzyme::ForLoopOp::create(rewriter, loc, warmupLoopTypes, c0, - numWarmupConst, c1, warmupInitArgs); + impulse::ForOp::create(rewriter, loc, warmupLoopTypes, c0, + numWarmupConst, c1, warmupInitArgs); Block *warmupBody = rewriter.createBlock(&warmupLoop.getRegion()); warmupBody->addArgument(i64TensorType, loc); // iteration index t @@ -992,7 +993,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { // Use log_step_size_avg at last iteration auto isLastIter = arith::CmpIOp::create( rewriter, loc, arith::CmpIPredicate::eq, iterT, 
lastIterConst); - Value adaptedStepSizeInLoop = enzyme::SelectOp::create( + Value adaptedStepSizeInLoop = impulse::SelectOp::create( rewriter, loc, scalarType, isLastIter, finalStepSizeFromDA, currentStepSizeFromDA); @@ -1026,17 +1027,17 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { if (adaptMassMatrix) { auto sampleType1D = RankedTensorType::get({positionSize}, elemType); Value sample1D = - enzyme::ReshapeOp::create(rewriter, loc, sampleType1D, sample.q); + impulse::ReshapeOp::create(rewriter, loc, sampleType1D, sample.q); WelfordState updatedWelfordAfterSample = updateWelford( rewriter, loc, welfordStateLoop, sample1D, welfordConfig); - conditionalWelford.mean = enzyme::SelectOp::create( + conditionalWelford.mean = impulse::SelectOp::create( rewriter, loc, welfordStateLoop.mean.getType(), isMiddleWindow, updatedWelfordAfterSample.mean, welfordStateLoop.mean); - conditionalWelford.m2 = enzyme::SelectOp::create( + conditionalWelford.m2 = impulse::SelectOp::create( rewriter, loc, welfordStateLoop.m2.getType(), isMiddleWindow, updatedWelfordAfterSample.m2, welfordStateLoop.m2); - conditionalWelford.n = enzyme::SelectOp::create( + conditionalWelford.n = impulse::SelectOp::create( rewriter, loc, welfordStateLoop.n.getType(), isMiddleWindow, updatedWelfordAfterSample.n, welfordStateLoop.n); } @@ -1064,8 +1065,8 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { Value newWindowIdx = arith::AddIOp::create(rewriter, loc, windowIdxLoop, c1); Value windowIdxAfterIncrement = - enzyme::SelectOp::create(rewriter, loc, i64TensorType, atWindowEnd, - newWindowIdx, windowIdxLoop); + impulse::SelectOp::create(rewriter, loc, i64TensorType, atWindowEnd, + newWindowIdx, windowIdxLoop); auto atMiddleWindowEnd = arith::AndIOp::create(rewriter, loc, atWindowEnd, isMiddleWindow); @@ -1087,8 +1088,8 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { for (Type t : updatedDaState.getTypes()) ifResultTypes.push_back(t); - auto ifOp = 
enzyme::IfOp::create(rewriter, loc, ifResultTypes, - atMiddleWindowEnd); + auto ifOp = impulse::IfOp::create(rewriter, loc, ifResultTypes, + atMiddleWindowEnd); { Block *trueBranch = rewriter.createBlock(&ifOp.getTrueBranch()); @@ -1124,7 +1125,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { trueYieldValues.push_back(v); } - enzyme::YieldOp::create(rewriter, loc, trueYieldValues); + impulse::YieldOp::create(rewriter, loc, trueYieldValues); } { @@ -1142,7 +1143,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { for (auto v : updatedDaState.toValues()) falseYieldValues.push_back(v); - enzyme::YieldOp::create(rewriter, loc, falseYieldValues); + impulse::YieldOp::create(rewriter, loc, falseYieldValues); } rewriter.setInsertionPointAfter(ifOp); @@ -1174,7 +1175,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { } warmupYieldValues.push_back(windowIdxAfterIncrement); - enzyme::YieldOp::create(rewriter, loc, warmupYieldValues); + impulse::YieldOp::create(rewriter, loc, warmupYieldValues); rewriter.setInsertionPointAfter(warmupLoop); @@ -1236,7 +1237,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { SmallVector loopResultTypes = { positionType, positionType, scalarType, currentRng.getType(), samplesBufferType, acceptedBufferType}; - auto forLoopOp = enzyme::ForLoopOp::create( + auto forLoopOp = impulse::ForOp::create( rewriter, loc, loopResultTypes, c0, numSamplesConst, c1, ValueRange{currentQ, currentGrad, currentU, currentRng, samplesBuffer, acceptedBuffer}); @@ -1262,7 +1263,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto sample = runSampleStepWithStepSize(rewriter, loc, qLoop, gradLoop, ULoop, rngLoop, adaptedStepSize); auto q_constrained = - MCMC::constrainPosition(rewriter, loc, sample.q, supports); + impulse::constrainPosition(rewriter, loc, sample.q, supports); // Storage index: idx = (i - start_idx) / thinning auto iMinusStart = @@ -1284,27 +1285,27 @@ struct ProbProgPass : 
public enzyme::impl::ProbProgPassBase { auto zeroCol = arith::ConstantOp::create( rewriter, loc, i64TensorType, DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(0))); - auto updatedSamplesBuffer = enzyme::DynamicUpdateSliceOp::create( + auto updatedSamplesBuffer = impulse::DynamicUpdateSliceOp::create( rewriter, loc, samplesBufferType, samplesBufferLoop, q_constrained, ValueRange{storageIdx, zeroCol}); - auto selectedSamplesBuffer = enzyme::SelectOp::create( + auto selectedSamplesBuffer = impulse::SelectOp::create( rewriter, loc, samplesBufferType, shouldStore, updatedSamplesBuffer, samplesBufferLoop); - auto accepted1D = enzyme::ReshapeOp::create( + auto accepted1D = impulse::ReshapeOp::create( rewriter, loc, RankedTensorType::get({1}, rewriter.getI1Type()), sample.accepted); - auto updatedAcceptedBuffer = enzyme::DynamicUpdateSliceOp::create( + auto updatedAcceptedBuffer = impulse::DynamicUpdateSliceOp::create( rewriter, loc, acceptedBufferType, acceptedBufferLoop, accepted1D, ValueRange{storageIdx}); - auto selectedAcceptedBuffer = enzyme::SelectOp::create( + auto selectedAcceptedBuffer = impulse::SelectOp::create( rewriter, loc, acceptedBufferType, shouldStore, updatedAcceptedBuffer, acceptedBufferLoop); - enzyme::YieldOp::create(rewriter, loc, - ValueRange{sample.q, sample.grad, sample.U, - sample.rng, selectedSamplesBuffer, - selectedAcceptedBuffer}); + impulse::YieldOp::create(rewriter, loc, + ValueRange{sample.q, sample.grad, sample.U, + sample.rng, selectedSamplesBuffer, + selectedAcceptedBuffer}); rewriter.setInsertionPointAfter(forLoopOp); Value finalQ = forLoopOp.getResult(0); @@ -1326,10 +1327,10 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { } }; - struct LowerMHPattern : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; + struct LowerMHPattern : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(enzyme::MHOp mhOp, + 
LogicalResult matchAndRewrite(impulse::MHOp mhOp, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; @@ -1371,7 +1372,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { for (auto t : fn.getResultTypes()) regenResultTypes.push_back(t); - auto regenerateOp = rewriter.create( + auto regenerateOp = rewriter.create( loc, /*resultTypes*/ regenResultTypes, /*fn*/ mhOp.getFnAttr(), @@ -1395,11 +1396,11 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto oneConst = arith::ConstantOp::create( rewriter, loc, weightType, DenseElementsAttr::get(weightType, 1.0)); - auto randomOp = enzyme::RandomOp::create( + auto randomOp = impulse::RandomOp::create( rewriter, loc, TypeRange{rngStateType, weightType}, newRng, zeroConst, oneConst, - enzyme::RngDistributionAttr::get(rewriter.getContext(), - enzyme::RngDistribution::UNIFORM)); + impulse::RngDistributionAttr::get(rewriter.getContext(), + impulse::RngDistribution::UNIFORM)); auto logRand = math::LogOp::create(rewriter, loc, randomOp.getResult()); Value finalRng = randomOp.getOutputRngState(); @@ -1408,7 +1409,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, loc, arith::CmpFPredicate::OLT, logRand, logAlpha); // 5. 
Select trace and weight based on acceptance - auto selectedTrace = enzyme::SelectOp::create( + auto selectedTrace = impulse::SelectOp::create( rewriter, loc, traceType, accepted, newTrace, oldTrace); auto selectedWeight = arith::SelectOp::create(rewriter, loc, accepted, newWeight, oldWeight); @@ -1420,10 +1421,10 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { }; struct LowerGeneratePattern - : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(enzyme::GenerateOp CI, + LogicalResult matchAndRewrite(impulse::GenerateOp CI, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; @@ -1479,7 +1480,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { int64_t currentTraceOffset = 0; SmallVector toErase; - auto result = NewF.walk([&](enzyme::SampleOp sampleOp) -> WalkResult { + auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sampleOp); @@ -1524,13 +1525,13 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto sliceType = RankedTensorType::get( {1, numElements}, resultType.getElementType()); - auto sliced = enzyme::SliceOp::create( + auto sliced = impulse::SliceOp::create( rewriter, sampleOp.getLoc(), sliceType, constraint, rewriter.getDenseI64ArrayAttr({0, constrainedOffset}), rewriter.getDenseI64ArrayAttr( {1, constrainedOffset + numElements}), rewriter.getDenseI64ArrayAttr({1, 1})); - auto extracted = enzyme::ReshapeOp::create( + auto extracted = impulse::ReshapeOp::create( rewriter, sampleOp.getLoc(), resultType, sliced); sampledValues[i] = extracted.getResult(); constrainedOffset += numElements; @@ -1625,7 +1626,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto flatSampleType = RankedTensorType::get( {1, numElements}, 
sampleType.getElementType()); - auto flatSample = enzyme::ReshapeOp::create( + auto flatSample = impulse::ReshapeOp::create( rewriter, sampleOp.getLoc(), flatSampleType, sampleValue); auto i64S = RankedTensorType::get({}, rewriter.getI64Type()); auto row0 = arith::ConstantOp::create( @@ -1635,7 +1636,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, sampleOp.getLoc(), i64S, DenseElementsAttr::get( i64S, rewriter.getI64IntegerAttr(currentTraceOffset))); - currTrace = enzyme::DynamicUpdateSliceOp::create( + currTrace = impulse::DynamicUpdateSliceOp::create( rewriter, sampleOp.getLoc(), traceType, currTrace, flatSample, ValueRange{row0, colOff}) .getResult(); @@ -1689,7 +1690,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { sampleOp, fn, CI.getConstrainedAddressesAttr(), sampleOp.getSymbolAttr(), symbolTable); - subConstraint = enzyme::SliceOp::create( + subConstraint = impulse::SliceOp::create( rewriter, sampleOp.getLoc(), subConstraintType, constraint, rewriter.getDenseI64ArrayAttr({0, subConstraintOffset}), rewriter.getDenseI64ArrayAttr( @@ -1711,7 +1712,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { for (auto t : genFn.getResultTypes()) genResultTypes.push_back(t); - auto nestedGenerate = enzyme::GenerateOp::create( + auto nestedGenerate = impulse::GenerateOp::create( rewriter, sampleOp.getLoc(), genResultTypes, sampleOp.getFnAttr(), sampleOp.getInputs(), subConstraint, subSelection, subConstrainedAddrs); @@ -1733,7 +1734,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, sampleOp.getLoc(), i64S, DenseElementsAttr::get( i64S, rewriter.getI64IntegerAttr(mergeOffset))); - currTrace = enzyme::DynamicUpdateSliceOp::create( + currTrace = impulse::DynamicUpdateSliceOp::create( rewriter, sampleOp.getLoc(), traceType, currTrace, subTrace, ValueRange{row0, colOff}) .getResult(); @@ -1789,10 +1790,10 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { }; struct 
LowerRegeneratePattern - : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; + : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(enzyme::RegenerateOp CI, + LogicalResult matchAndRewrite(impulse::RegenerateOp CI, PatternRewriter &rewriter) const override { SymbolTableCollection symbolTable; @@ -1842,7 +1843,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { int64_t currentTraceOffset = 0; SmallVector toErase; - auto result = NewF.walk([&](enzyme::SampleOp sampleOp) -> WalkResult { + auto result = NewF.walk([&](impulse::SampleOp sampleOp) -> WalkResult { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sampleOp); @@ -1899,13 +1900,13 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto sliceType = RankedTensorType::get( {1, numElements}, resultType.getElementType()); - auto sliced = enzyme::SliceOp::create( + auto sliced = impulse::SliceOp::create( rewriter, sampleOp.getLoc(), sliceType, prevTrace, rewriter.getDenseI64ArrayAttr({0, extractOffset}), rewriter.getDenseI64ArrayAttr( {1, extractOffset + numElements}), rewriter.getDenseI64ArrayAttr({1, 1})); - auto extracted = enzyme::ReshapeOp::create( + auto extracted = impulse::ReshapeOp::create( rewriter, sampleOp.getLoc(), resultType, sliced); sampledValues[i] = extracted.getResult(); extractOffset += numElements; @@ -1959,7 +1960,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto flatSampleType = RankedTensorType::get( {1, numElements}, sampleType.getElementType()); - auto flatSample = enzyme::ReshapeOp::create( + auto flatSample = impulse::ReshapeOp::create( rewriter, sampleOp.getLoc(), flatSampleType, sampleValue); auto i64S = RankedTensorType::get({}, rewriter.getI64Type()); auto row0 = arith::ConstantOp::create( @@ -1969,7 +1970,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, sampleOp.getLoc(), i64S, 
DenseElementsAttr::get( i64S, rewriter.getI64IntegerAttr(currentTraceOffset))); - currTrace = enzyme::DynamicUpdateSliceOp::create( + currTrace = impulse::DynamicUpdateSliceOp::create( rewriter, sampleOp.getLoc(), traceType, currTrace, flatSample, ValueRange{row0, colOff}) .getResult(); @@ -2019,7 +2020,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { auto subTraceType = RankedTensorType::get({1, subPositionSize}, rewriter.getF64Type()); - Value subPrevTrace = enzyme::SliceOp::create( + Value subPrevTrace = impulse::SliceOp::create( rewriter, sampleOp.getLoc(), subTraceType, prevTrace, rewriter.getDenseI64ArrayAttr({0, mergeOffset}), rewriter.getDenseI64ArrayAttr( @@ -2034,7 +2035,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { for (auto t : genFn.getResultTypes()) regenResultTypes.push_back(t); - auto nestedRegenerate = enzyme::RegenerateOp::create( + auto nestedRegenerate = impulse::RegenerateOp::create( rewriter, sampleOp.getLoc(), regenResultTypes, sampleOp.getFnAttr(), sampleOp.getInputs(), subPrevTrace, subSelection, subRegenerateAddrs); @@ -2053,7 +2054,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, sampleOp.getLoc(), i64S, DenseElementsAttr::get( i64S, rewriter.getI64IntegerAttr(mergeOffset))); - currTrace = enzyme::DynamicUpdateSliceOp::create( + currTrace = impulse::DynamicUpdateSliceOp::create( rewriter, sampleOp.getLoc(), traceType, currTrace, subTrace, ValueRange{row0, colOff}) .getResult(); diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index d362e0019b79..28f39ffa1ab8 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -37,6 +37,7 @@ #include "mlir/Transforms/Passes.h" #include "Dialect/Dialect.h" +#include "Dialect/Impulse/Impulse.h" #include "Dialect/LLVMExt/LLVMExt.h" #include "Implementations/CoreDialectsAutoDiffImplementations.h" #include "Passes/Passes.h" @@ -75,6 +76,7 @@ int main(int 
argc, char **argv) { registry.insert(); registry.insert(); + registry.insert(); registry.insert(); mlir::enzyme::registerenzymePasses(); diff --git a/enzyme/test/MLIR/ProbProg/exp_transform.mlir b/enzyme/test/MLIR/ProbProg/exp_transform.mlir index afd8cb77e69e..c041269547e3 100644 --- a/enzyme/test/MLIR/ProbProg/exp_transform.mlir +++ b/enzyme/test/MLIR/ProbProg/exp_transform.mlir @@ -5,8 +5,8 @@ module { func.func private @logpdf(%x : tensor, %rate : tensor) -> tensor func.func @test(%rng : tensor<2xui64>, %rate : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @exponential(%rng, %rate) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s", support = #enzyme.support } : (tensor<2xui64>, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @exponential(%s#0, %rate) { logpdf = @logpdf, symbol = #enzyme.symbol<2>, name="t", support = #enzyme.support } : (tensor<2xui64>, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @exponential(%rng, %rate) { logpdf = @logpdf, symbol = #impulse.symbol<1>, name="s", support = #impulse.support } : (tensor<2xui64>, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @exponential(%s#0, %rate) { logpdf = @logpdf, symbol = #impulse.symbol<2>, name="t", support = #impulse.support } : (tensor<2xui64>, tensor) -> (tensor<2xui64>, tensor) return %t#0, %t#1 : tensor<2xui64>, tensor } @@ -17,89 +17,89 @@ module { %inverse_mass_matrix = arith.constant dense<[[1.0, 0.0], [0.0, 1.0]]> : tensor<2x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %rate) given %init_trace + %res:8 = impulse.infer @test(%rng, %rate) given %init_trace inverse_mass_matrix = %inverse_mass_matrix step_size = %step_size - { hmc_config = #enzyme.hmc_config, name = "hmc", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], all_addresses = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], num_warmup = 0, num_samples = 1 } + { hmc_config = #impulse.hmc_config, name = "hmc", selection = 
[[#impulse.symbol<1>], [#impulse.symbol<2>]], all_addresses = [[#impulse.symbol<1>], [#impulse.symbol<2>]], num_warmup = 0, num_samples = 1 } : (tensor<2xui64>, tensor, tensor<1x2xf64>, tensor<2x2xf64>, tensor) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor, tensor, tensor<1x2xf64>) return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64> } } // CHECK: %[[INIT_TRACE:.+]] = arith.constant dense<1.0{{.*}}> : tensor<1x2xf64> -// CHECK: enzyme.dynamic_slice %[[INIT_TRACE]] -// CHECK: enzyme.dynamic_update_slice -// CHECK: enzyme.dynamic_slice %[[INIT_TRACE]] -// CHECK: %[[EXTRACTED_POS:.+]] = enzyme.dynamic_update_slice -// CHECK: %[[FLATTENED:.+]] = enzyme.reshape %[[EXTRACTED_POS]] : (tensor<1x2xf64>) -> tensor<2xf64> +// CHECK: impulse.dynamic_slice %[[INIT_TRACE]] +// CHECK: impulse.dynamic_update_slice +// CHECK: impulse.dynamic_slice %[[INIT_TRACE]] +// CHECK: %[[EXTRACTED_POS:.+]] = impulse.dynamic_update_slice +// CHECK: %[[FLATTENED:.+]] = impulse.reshape %[[EXTRACTED_POS]] : (tensor<1x2xf64>) -> tensor<2xf64> -// CHECK: %[[SAMPLE1:.+]] = enzyme.slice %[[FLATTENED]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[SAMPLE1:.+]] = impulse.slice %[[FLATTENED]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> // CHECK: %[[LOG1:.+]] = math.log %[[SAMPLE1]] : tensor<1xf64> -// CHECK: %[[SAMPLE2:.+]] = enzyme.slice %[[FLATTENED]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[SAMPLE2:.+]] = impulse.slice %[[FLATTENED]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> // CHECK: %[[LOG2:.+]] = math.log %[[SAMPLE2]] : tensor<1xf64> -// CHECK: %[[ELEM1_SLICED:.+]] = enzyme.dynamic_slice %[[LOG1]], %{{.+}} {slice_sizes = array} : (tensor<1xf64>, tensor) -> tensor<1xf64> -// CHECK: 
%[[ELEM1:.+]] = enzyme.reshape %[[ELEM1_SLICED]] : (tensor<1xf64>) -> tensor -// CHECK: %[[ELEM1_1D:.+]] = enzyme.reshape %[[ELEM1]] : (tensor) -> tensor<1xf64> -// CHECK: %[[POS1:.+]] = enzyme.dynamic_update_slice %{{.+}}, %[[ELEM1_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> -// CHECK: %[[ELEM2_SLICED:.+]] = enzyme.dynamic_slice %[[LOG2]], %{{.+}} {slice_sizes = array} : (tensor<1xf64>, tensor) -> tensor<1xf64> -// CHECK: %[[ELEM2:.+]] = enzyme.reshape %[[ELEM2_SLICED]] : (tensor<1xf64>) -> tensor -// CHECK: %[[ELEM2_1D:.+]] = enzyme.reshape %[[ELEM2]] : (tensor) -> tensor<1xf64> -// CHECK: %[[UNCONSTRAINED_POS:.+]] = enzyme.dynamic_update_slice %[[POS1]], %[[ELEM2_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> +// CHECK: %[[ELEM1_SLICED:.+]] = impulse.dynamic_slice %[[LOG1]], %{{.+}} {slice_sizes = array} : (tensor<1xf64>, tensor) -> tensor<1xf64> +// CHECK: %[[ELEM1:.+]] = impulse.reshape %[[ELEM1_SLICED]] : (tensor<1xf64>) -> tensor +// CHECK: %[[ELEM1_1D:.+]] = impulse.reshape %[[ELEM1]] : (tensor) -> tensor<1xf64> +// CHECK: %[[POS1:.+]] = impulse.dynamic_update_slice %{{.+}}, %[[ELEM1_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> +// CHECK: %[[ELEM2_SLICED:.+]] = impulse.dynamic_slice %[[LOG2]], %{{.+}} {slice_sizes = array} : (tensor<1xf64>, tensor) -> tensor<1xf64> +// CHECK: %[[ELEM2:.+]] = impulse.reshape %[[ELEM2_SLICED]] : (tensor<1xf64>) -> tensor +// CHECK: %[[ELEM2_1D:.+]] = impulse.reshape %[[ELEM2]] : (tensor) -> tensor<1xf64> +// CHECK: %[[UNCONSTRAINED_POS:.+]] = impulse.dynamic_update_slice %[[POS1]], %[[ELEM2_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> -// CHECK: %[[UNCONSTRAINED_POS_2D:.+]] = enzyme.reshape %[[UNCONSTRAINED_POS]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK: %[[UNCONSTRAINED_POS_2D:.+]] = impulse.reshape %[[UNCONSTRAINED_POS]] : (tensor<2xf64>) -> tensor<1x2xf64> // // --- Constrain position for initial generate --- 
-// CHECK: enzyme.dynamic_slice -// CHECK: enzyme.dynamic_update_slice -// CHECK: enzyme.dynamic_slice -// CHECK: enzyme.dynamic_update_slice +// CHECK: impulse.dynamic_slice +// CHECK: impulse.dynamic_update_slice +// CHECK: impulse.dynamic_slice +// CHECK: impulse.dynamic_update_slice // CHECK: %[[INIT_GEN:.+]]:4 = call @test.generate{{.*}}(%{{.+}}, %{{.+}}, %{{.+}}) // CHECK: %[[NEG_WEIGHT:.+]] = arith.negf %[[INIT_GEN]]#1 : tensor -// CHECK: %[[JAC_POS:.+]] = enzyme.reshape %[[UNCONSTRAINED_POS_2D]] : (tensor<1x2xf64>) -> tensor<2xf64> -// CHECK: %[[Z1:.+]] = enzyme.slice %[[JAC_POS]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> -// CHECK: %[[SUM1:.+]] = enzyme.dot %[[Z1]], %{{.+}} {{.*}} : (tensor<1xf64>, tensor<1xf64>) -> tensor +// CHECK: %[[JAC_POS:.+]] = impulse.reshape %[[UNCONSTRAINED_POS_2D]] : (tensor<1x2xf64>) -> tensor<2xf64> +// CHECK: %[[Z1:.+]] = impulse.slice %[[JAC_POS]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[SUM1:.+]] = impulse.dot %[[Z1]], %{{.+}} {{.*}} : (tensor<1xf64>, tensor<1xf64>) -> tensor // CHECK: %[[PARTIAL_SUM:.+]] = arith.addf %[[SUM1]], %{{.+}} : tensor -// CHECK: %[[Z2:.+]] = enzyme.slice %[[JAC_POS]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> -// CHECK: %[[SUM2:.+]] = enzyme.dot %[[Z2]], %{{.+}} {{.*}} : (tensor<1xf64>, tensor<1xf64>) -> tensor +// CHECK: %[[Z2:.+]] = impulse.slice %[[JAC_POS]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[SUM2:.+]] = impulse.dot %[[Z2]], %{{.+}} {{.*}} : (tensor<1xf64>, tensor<1xf64>) -> tensor // CHECK: %[[TOTAL_JACOBIAN:.+]] = arith.addf %[[PARTIAL_SUM]], %[[SUM2]] : tensor // CHECK: %[[U0:.+]] = arith.subf %[[NEG_WEIGHT]], %[[TOTAL_JACOBIAN]] : tensor // CHECK: %[[AUTODIFF:.+]] = enzyme.autodiff_region(%[[UNCONSTRAINED_POS_2D]], %{{.+}}) { // 
CHECK: ^bb0(%[[ARG:.+]]: tensor<1x2xf64>): -// CHECK: %[[ARG_1D:.+]] = enzyme.reshape %[[ARG]] : (tensor<1x2xf64>) -> tensor<2xf64> -// CHECK: %[[Z_SAMPLE1:.+]] = enzyme.slice %[[ARG_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[ARG_1D:.+]] = impulse.reshape %[[ARG]] : (tensor<1x2xf64>) -> tensor<2xf64> +// CHECK: %[[Z_SAMPLE1:.+]] = impulse.slice %[[ARG_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> // CHECK: %[[EXP1:.+]] = math.exp %[[Z_SAMPLE1]] : tensor<1xf64> -// CHECK: %[[Z_SAMPLE2:.+]] = enzyme.slice %[[ARG_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[Z_SAMPLE2:.+]] = impulse.slice %[[ARG_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> // CHECK: %[[EXP2:.+]] = math.exp %[[Z_SAMPLE2]] : tensor<1xf64> -// CHECK: %[[C_ELEM1_S:.+]] = enzyme.dynamic_slice %[[EXP1]], %{{.+}} {slice_sizes = array} : (tensor<1xf64>, tensor) -> tensor<1xf64> -// CHECK: %[[C_ELEM1:.+]] = enzyme.reshape %[[C_ELEM1_S]] : (tensor<1xf64>) -> tensor -// CHECK: %[[C_ELEM1_1D:.+]] = enzyme.reshape %[[C_ELEM1]] : (tensor) -> tensor<1xf64> -// CHECK: %[[C_POS1:.+]] = enzyme.dynamic_update_slice %{{.+}}, %[[C_ELEM1_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> -// CHECK: %[[C_ELEM2_S:.+]] = enzyme.dynamic_slice %[[EXP2]], %{{.+}} {slice_sizes = array} : (tensor<1xf64>, tensor) -> tensor<1xf64> -// CHECK: %[[C_ELEM2:.+]] = enzyme.reshape %[[C_ELEM2_S]] : (tensor<1xf64>) -> tensor -// CHECK: %[[C_ELEM2_1D:.+]] = enzyme.reshape %[[C_ELEM2]] : (tensor) -> tensor<1xf64> -// CHECK: %[[CONSTRAINED_POS:.+]] = enzyme.dynamic_update_slice %[[C_POS1]], %[[C_ELEM2_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> +// CHECK: %[[C_ELEM1_S:.+]] = impulse.dynamic_slice %[[EXP1]], %{{.+}} {slice_sizes = array} 
: (tensor<1xf64>, tensor) -> tensor<1xf64> +// CHECK: %[[C_ELEM1:.+]] = impulse.reshape %[[C_ELEM1_S]] : (tensor<1xf64>) -> tensor +// CHECK: %[[C_ELEM1_1D:.+]] = impulse.reshape %[[C_ELEM1]] : (tensor) -> tensor<1xf64> +// CHECK: %[[C_POS1:.+]] = impulse.dynamic_update_slice %{{.+}}, %[[C_ELEM1_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> +// CHECK: %[[C_ELEM2_S:.+]] = impulse.dynamic_slice %[[EXP2]], %{{.+}} {slice_sizes = array} : (tensor<1xf64>, tensor) -> tensor<1xf64> +// CHECK: %[[C_ELEM2:.+]] = impulse.reshape %[[C_ELEM2_S]] : (tensor<1xf64>) -> tensor +// CHECK: %[[C_ELEM2_1D:.+]] = impulse.reshape %[[C_ELEM2]] : (tensor) -> tensor<1xf64> +// CHECK: %[[CONSTRAINED_POS:.+]] = impulse.dynamic_update_slice %[[C_POS1]], %[[C_ELEM2_1D]], %{{.+}} : (tensor<2xf64>, tensor<1xf64>, tensor) -> tensor<2xf64> -// CHECK: enzyme.reshape %[[CONSTRAINED_POS]] +// CHECK: impulse.reshape %[[CONSTRAINED_POS]] // CHECK: %[[GEN_RES:.+]]:4 = func.call @test.generate{{.*}} // CHECK: %[[NEG_GEN_WEIGHT:.+]] = arith.negf %[[GEN_RES]]#1 : tensor -// CHECK: %[[ARG_1D_2:.+]] = enzyme.reshape %[[ARG]] : (tensor<1x2xf64>) -> tensor<2xf64> -// CHECK: %[[ARG_Z1:.+]] = enzyme.slice %[[ARG_1D_2]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> -// CHECK: %[[ARG_DOT1:.+]] = enzyme.dot %[[ARG_Z1]], %{{.+}} +// CHECK: %[[ARG_1D_2:.+]] = impulse.reshape %[[ARG]] : (tensor<1x2xf64>) -> tensor<2xf64> +// CHECK: %[[ARG_Z1:.+]] = impulse.slice %[[ARG_1D_2]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[ARG_DOT1:.+]] = impulse.dot %[[ARG_Z1]], %{{.+}} // CHECK: %[[ARG_PARTIAL:.+]] = arith.addf %[[ARG_DOT1]], %{{.+}} : tensor -// CHECK: %[[ARG_Z2:.+]] = enzyme.slice %[[ARG_1D_2]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> -// CHECK: %[[ARG_DOT2:.+]] = enzyme.dot %[[ARG_Z2]], %{{.+}} +// CHECK: 
%[[ARG_Z2:.+]] = impulse.slice %[[ARG_1D_2]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[ARG_DOT2:.+]] = impulse.dot %[[ARG_Z2]], %{{.+}} // CHECK: %[[JAC_SUM:.+]] = arith.addf %[[ARG_PARTIAL]], %[[ARG_DOT2]] : tensor // CHECK: %[[ADJUSTED_U:.+]] = arith.subf %[[NEG_GEN_WEIGHT]], %[[JAC_SUM]] : tensor // CHECK: enzyme.yield %[[ADJUSTED_U]], %{{.+}} : tensor, tensor<2xui64> // CHECK: } -// CHECK: enzyme.for_loop +// CHECK: impulse.for // CHECK: enzyme.autodiff_region // CHECK: math.exp @@ -110,23 +110,23 @@ module { // CHECK: math.exp // CHECK: arith.minimumf -// CHECK: enzyme.random +// CHECK: impulse.random // CHECK: arith.cmpf olt -// CHECK: %[[FINAL_SELECT:.+]] = enzyme.select -// CHECK: %[[FINAL_1D:.+]] = enzyme.reshape %[[FINAL_SELECT]] : (tensor<1x2xf64>) -> tensor<2xf64> -// CHECK: %[[FINAL_SAMPLE1:.+]] = enzyme.slice %[[FINAL_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[FINAL_SELECT:.+]] = impulse.select +// CHECK: %[[FINAL_1D:.+]] = impulse.reshape %[[FINAL_SELECT]] : (tensor<1x2xf64>) -> tensor<2xf64> +// CHECK: %[[FINAL_SAMPLE1:.+]] = impulse.slice %[[FINAL_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> // CHECK: %[[FINAL_EXP1:.+]] = math.exp %[[FINAL_SAMPLE1]] : tensor<1xf64> -// CHECK: %[[FINAL_SAMPLE2:.+]] = enzyme.slice %[[FINAL_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> +// CHECK: %[[FINAL_SAMPLE2:.+]] = impulse.slice %[[FINAL_1D]] {limit_indices = array, start_indices = array, strides = array} : (tensor<2xf64>) -> tensor<1xf64> // CHECK: %[[FINAL_EXP2:.+]] = math.exp %[[FINAL_SAMPLE2]] : tensor<1xf64> -// CHECK: enzyme.dynamic_update_slice +// CHECK: impulse.dynamic_update_slice // CHECK: return // CHECK-LABEL: func.func @test.generate -// CHECK: enzyme.slice %{{.+}} {limit_indices = array, 
start_indices = array, strides = array} -// CHECK: enzyme.reshape +// CHECK: impulse.slice %{{.+}} {limit_indices = array, start_indices = array, strides = array} +// CHECK: impulse.reshape // CHECK: call @logpdf -// CHECK: enzyme.slice %{{.+}} {limit_indices = array, start_indices = array, strides = array} -// CHECK: enzyme.reshape +// CHECK: impulse.slice %{{.+}} {limit_indices = array, start_indices = array, strides = array} +// CHECK: impulse.reshape // CHECK: call @logpdf // CHECK: return diff --git a/enzyme/test/MLIR/ProbProg/generate.mlir b/enzyme/test/MLIR/ProbProg/generate.mlir index d158af65eb99..f5c7d6c5c3d5 100644 --- a/enzyme/test/MLIR/ProbProg/generate.mlir +++ b/enzyme/test/MLIR/ProbProg/generate.mlir @@ -6,15 +6,15 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @model(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %t#1 : tensor<2xui64>, tensor } func.func @test_base(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) { %constraint = arith.constant dense<[[1.5]]> : tensor<1x1xf64> - %res:4 = enzyme.generate @model(%rng, %mean, %stddev) given %constraint - { selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], constrained_addresses = 
[[#enzyme.symbol<2>]] } + %res:4 = impulse.generate @model(%rng, %mean, %stddev) given %constraint + { selection = [[#impulse.symbol<1>], [#impulse.symbol<2>]], constrained_addresses = [[#impulse.symbol<2>]] } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) return %res#0, %res#1, %res#2, %res#3 : tensor<1x2xf64>, tensor, tensor<2xui64>, tensor } @@ -35,14 +35,14 @@ module { // BASE: %[[S1:.+]]:2 = call @normal(%[[ARG1]], %[[ARG2]], %[[ARG3]]) // BASE-NEXT: %[[LP1:.+]] = call @logpdf(%[[S1]]#1, %[[ARG2]], %[[ARG3]]) // BASE-NEXT: %[[W1:.+]] = arith.addf %[[LP1]], %[[ZERO]] : tensor -// BASE-NEXT: %[[RS1:.+]] = enzyme.reshape %[[S1]]#1 : (tensor) -> tensor<1x1xf64> -// BASE-NEXT: %[[TR1:.+]] = enzyme.dynamic_update_slice %[[TRACE_INIT]], %[[RS1]], %[[C0]], %[[C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> -// BASE-NEXT: %[[SLICED:.+]] = enzyme.slice %[[ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x1xf64>) -> tensor<1x1xf64> -// BASE-NEXT: %[[CONSTRAINED:.+]] = enzyme.reshape %[[SLICED]] : (tensor<1x1xf64>) -> tensor +// BASE-NEXT: %[[RS1:.+]] = impulse.reshape %[[S1]]#1 : (tensor) -> tensor<1x1xf64> +// BASE-NEXT: %[[TR1:.+]] = impulse.dynamic_update_slice %[[TRACE_INIT]], %[[RS1]], %[[C0]], %[[C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// BASE-NEXT: %[[SLICED:.+]] = impulse.slice %[[ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x1xf64>) -> tensor<1x1xf64> +// BASE-NEXT: %[[CONSTRAINED:.+]] = impulse.reshape %[[SLICED]] : (tensor<1x1xf64>) -> tensor // BASE-NEXT: %[[LP2:.+]] = call @logpdf(%[[CONSTRAINED]], %[[S1]]#1, %[[ARG3]]) // BASE-NEXT: %[[W2:.+]] = arith.addf %[[W1]], %[[LP2]] : tensor -// BASE-NEXT: %[[RS2:.+]] = enzyme.reshape %[[CONSTRAINED]] : (tensor) -> tensor<1x1xf64> -// BASE-NEXT: %[[TR2:.+]] = enzyme.dynamic_update_slice %[[TR1]], %[[RS2]], %[[C0]], %[[C1]] : 
(tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// BASE-NEXT: %[[RS2:.+]] = impulse.reshape %[[CONSTRAINED]] : (tensor) -> tensor<1x1xf64> +// BASE-NEXT: %[[TR2:.+]] = impulse.dynamic_update_slice %[[TR1]], %[[RS2]], %[[C0]], %[[C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // BASE-NEXT: return %[[TR2]], %[[W2]], %[[S1]]#0, %[[CONSTRAINED]] // ----- @@ -52,22 +52,22 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @inner(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<4> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<4> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %s#1, %t#1 : tensor<2xui64>, tensor, tensor } func.func @outer(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:3 = enzyme.sample @inner(%s#0, %s#1, %stddev) { symbol = #enzyme.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:3 = impulse.sample @inner(%s#0, %s#1, 
%stddev) { symbol = #impulse.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor, tensor) return %t#0, %t#1, %t#2 : tensor<2xui64>, tensor, tensor } func.func @test_hier(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor) { %constraint = arith.constant dense<[[1.0, 2.0]]> : tensor<1x2xf64> - %res:5 = enzyme.generate @outer(%rng, %mean, %stddev) given %constraint - { selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>, #enzyme.symbol<3>], [#enzyme.symbol<2>, #enzyme.symbol<4>]], - constrained_addresses = [[#enzyme.symbol<1>], [#enzyme.symbol<2>, #enzyme.symbol<3>]] } + %res:5 = impulse.generate @outer(%rng, %mean, %stddev) given %constraint + { selection = [[#impulse.symbol<1>], [#impulse.symbol<2>, #impulse.symbol<3>], [#impulse.symbol<2>, #impulse.symbol<4>]], + constrained_addresses = [[#impulse.symbol<1>], [#impulse.symbol<2>, #impulse.symbol<3>]] } : (tensor<2xui64>, tensor, tensor, tensor<1x2xf64>) -> (tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor) return %res#0, %res#1, %res#2, %res#3, %res#4 : tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor } @@ -85,16 +85,16 @@ module { // HIER-DAG: %[[O_C0:.+]] = arith.constant dense<0> : tensor // HIER-DAG: %[[O_ZERO:.+]] = arith.constant dense<0.000000e+00> : tensor // HIER-DAG: %[[O_TRACE_INIT:.+]] = arith.constant dense<0.000000e+00> : tensor<1x3xf64> -// HIER: %[[O_SLICED1:.+]] = enzyme.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> -// HIER-NEXT: %[[O_CONSTRAINED:.+]] = enzyme.reshape %[[O_SLICED1]] : (tensor<1x1xf64>) -> tensor +// HIER: %[[O_SLICED1:.+]] = impulse.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> +// HIER-NEXT: %[[O_CONSTRAINED:.+]] = impulse.reshape %[[O_SLICED1]] : (tensor<1x1xf64>) -> tensor // HIER-NEXT: %[[O_LP1:.+]] = call 
@logpdf(%[[O_CONSTRAINED]], %[[O_ARG2]], %[[O_ARG3]]) // HIER-NEXT: %[[O_W1:.+]] = arith.addf %[[O_LP1]], %[[O_ZERO]] : tensor -// HIER-NEXT: %[[O_RS1:.+]] = enzyme.reshape %[[O_CONSTRAINED]] : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[O_TR1:.+]] = enzyme.dynamic_update_slice %[[O_TRACE_INIT]], %[[O_RS1]], %[[O_C0]], %[[O_C0]] : (tensor<1x3xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x3xf64> -// HIER-NEXT: %[[O_SUB_CONSTRAINT:.+]] = enzyme.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> +// HIER-NEXT: %[[O_RS1:.+]] = impulse.reshape %[[O_CONSTRAINED]] : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[O_TR1:.+]] = impulse.dynamic_update_slice %[[O_TRACE_INIT]], %[[O_RS1]], %[[O_C0]], %[[O_C0]] : (tensor<1x3xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x3xf64> +// HIER-NEXT: %[[O_SUB_CONSTRAINT:.+]] = impulse.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> // HIER-NEXT: %[[O_NESTED:.+]]:5 = call @inner.generate(%[[O_SUB_CONSTRAINT]], %[[O_ARG1]], %[[O_CONSTRAINED]], %[[O_ARG3]]) : (tensor<1x1xf64>, tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor, tensor) // HIER-NEXT: %[[O_W2:.+]] = arith.addf %[[O_W1]], %[[O_NESTED]]#1 : tensor -// HIER-NEXT: %[[O_TR2:.+]] = enzyme.dynamic_update_slice %[[O_TR1]], %[[O_NESTED]]#0, %[[O_C0]], %[[O_C1]] : (tensor<1x3xf64>, tensor<1x2xf64>, tensor, tensor) -> tensor<1x3xf64> +// HIER-NEXT: %[[O_TR2:.+]] = impulse.dynamic_update_slice %[[O_TR1]], %[[O_NESTED]]#0, %[[O_C0]], %[[O_C1]] : (tensor<1x3xf64>, tensor<1x2xf64>, tensor, tensor) -> tensor<1x3xf64> // HIER-NEXT: return %[[O_TR2]], %[[O_W2]], %[[O_NESTED]]#2, %[[O_NESTED]]#3, %[[O_NESTED]]#4 // HIER-LABEL: func.func @inner.generate @@ -103,15 +103,15 @@ module { // HIER-DAG: %[[I_C0:.+]] = arith.constant dense<0> : tensor // HIER-DAG: %[[I_ZERO:.+]] = arith.constant dense<0.000000e+00> : 
tensor // HIER-DAG: %[[I_TRACE_INIT:.+]] = arith.constant dense<0.000000e+00> : tensor<1x2xf64> -// HIER: %[[I_SLICED:.+]] = enzyme.slice %[[I_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x1xf64>) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_CONSTRAINED:.+]] = enzyme.reshape %[[I_SLICED]] : (tensor<1x1xf64>) -> tensor +// HIER: %[[I_SLICED:.+]] = impulse.slice %[[I_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x1xf64>) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_CONSTRAINED:.+]] = impulse.reshape %[[I_SLICED]] : (tensor<1x1xf64>) -> tensor // HIER-NEXT: %[[I_LP1:.+]] = call @logpdf(%[[I_CONSTRAINED]], %[[I_ARG2]], %[[I_ARG3]]) // HIER-NEXT: %[[I_W1:.+]] = arith.addf %[[I_LP1]], %[[I_ZERO]] : tensor -// HIER-NEXT: %[[I_RS1:.+]] = enzyme.reshape %[[I_CONSTRAINED]] : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_TR1:.+]] = enzyme.dynamic_update_slice %[[I_TRACE_INIT]], %[[I_RS1]], %[[I_C0]], %[[I_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// HIER-NEXT: %[[I_RS1:.+]] = impulse.reshape %[[I_CONSTRAINED]] : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_TR1:.+]] = impulse.dynamic_update_slice %[[I_TRACE_INIT]], %[[I_RS1]], %[[I_C0]], %[[I_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // HIER-NEXT: %[[I_S2:.+]]:2 = call @normal(%[[I_ARG1]], %[[I_ARG2]], %[[I_ARG3]]) // HIER-NEXT: %[[I_LP2:.+]] = call @logpdf(%[[I_S2]]#1, %[[I_ARG2]], %[[I_ARG3]]) // HIER-NEXT: %[[I_W2:.+]] = arith.addf %[[I_W1]], %[[I_LP2]] : tensor -// HIER-NEXT: %[[I_RS2:.+]] = enzyme.reshape %[[I_S2]]#1 : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_TR2:.+]] = enzyme.dynamic_update_slice %[[I_TR1]], %[[I_RS2]], %[[I_C0]], %[[I_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// HIER-NEXT: %[[I_RS2:.+]] = impulse.reshape %[[I_S2]]#1 : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_TR2:.+]] = impulse.dynamic_update_slice %[[I_TR1]], 
%[[I_RS2]], %[[I_C0]], %[[I_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // HIER-NEXT: return %[[I_TR2]], %[[I_W2]], %[[I_S2]]#0, %[[I_CONSTRAINED]], %[[I_S2]]#1 diff --git a/enzyme/test/MLIR/ProbProg/hmc_diag_mass.mlir b/enzyme/test/MLIR/ProbProg/hmc_diag_mass.mlir index 20e6b130ddbf..46eabe0797b0 100644 --- a/enzyme/test/MLIR/ProbProg/hmc_diag_mass.mlir +++ b/enzyme/test/MLIR/ProbProg/hmc_diag_mass.mlir @@ -5,8 +5,8 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @test(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<2>, name="t" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<2>, name="t" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %s#1, %t#1 : tensor<2xui64>, tensor, tensor } @@ -14,11 +14,11 @@ module { %init_trace = arith.constant dense<[[0.0, 0.0]]> : tensor<1x2xf64> %inv_mass = arith.constant dense<[2.0, 3.0]> : tensor<2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace inverse_mass_matrix = %inv_mass step_size = %step_size - { hmc_config = #enzyme.hmc_config, - name = "hmc_diag", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], all_addresses = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], 
num_warmup = 0, num_samples = 1 } + { hmc_config = #impulse.hmc_config, + name = "hmc_diag", selection = [[#impulse.symbol<1>], [#impulse.symbol<2>]], all_addresses = [[#impulse.symbol<1>], [#impulse.symbol<2>]], num_warmup = 0, num_samples = 1 } : (tensor<2xui64>, tensor, tensor, tensor<1x2xf64>, tensor<2xf64>, tensor) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor, tensor, tensor<1x2xf64>) return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64> } @@ -27,11 +27,11 @@ module { %init_trace = arith.constant dense<[[0.0, 0.0]]> : tensor<1x2xf64> %inv_mass = arith.constant dense<[2.0, 3.0]> : tensor<2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace inverse_mass_matrix = %inv_mass step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "nuts_diag", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], all_addresses = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], num_warmup = 0, num_samples = 1 } + { nuts_config = #impulse.nuts_config, + name = "nuts_diag", selection = [[#impulse.symbol<1>], [#impulse.symbol<2>]], all_addresses = [[#impulse.symbol<1>], [#impulse.symbol<2>]], num_warmup = 0, num_samples = 1 } : (tensor<2xui64>, tensor, tensor, tensor<1x2xf64>, tensor<2xf64>, tensor) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor, tensor, tensor<1x2xf64>) return %res#0, %res#1, %res#2 : tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64> } @@ -44,24 +44,24 @@ module { // CHECK-DAG: %[[HALF:.+]] = arith.constant dense<5.000000e-01> : tensor // CHECK-DAG: %[[ZERO_F:.+]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ONE:.+]] = arith.constant dense<1.000000e+00> : tensor -// CHECK: enzyme.for_loop +// CHECK: impulse.for // CHECK: ^bb0( -// CHECK: %[[EPS_RNG:.+]], %[[EPS:.+]] = enzyme.random {{.*}} 
{rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x2xf64>) -// CHECK-NEXT: %[[SQRT_2D:.+]] = enzyme.reshape %[[MASS_SQRT]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK: %[[EPS_RNG:.+]], %[[EPS:.+]] = impulse.random {{.*}} {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x2xf64>) +// CHECK-NEXT: %[[SQRT_2D:.+]] = impulse.reshape %[[MASS_SQRT]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: %[[P:.+]] = arith.mulf %[[SQRT_2D]], %[[EPS]] : tensor<1x2xf64> -// CHECK-NEXT: %[[INV_2D:.+]] = enzyme.reshape %[[INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK-NEXT: %[[INV_2D:.+]] = impulse.reshape %[[INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: %[[V:.+]] = arith.mulf %[[INV_2D]], %[[P]] : tensor<1x2xf64> -// CHECK-NEXT: %[[KE_DOT:.+]] = enzyme.dot %[[P]], %[[V]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor +// CHECK-NEXT: %[[KE_DOT:.+]] = impulse.dot %[[P]], %[[V]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor // CHECK-NEXT: %[[KE:.+]] = arith.mulf %[[KE_DOT]], %[[HALF]] : tensor -// CHECK: enzyme.for_loop +// CHECK: impulse.for // CHECK: ^bb0(%[[LF_I:.+]]: tensor, %[[LF_Q:.+]]: tensor<1x2xf64>, %[[LF_P:.+]]: tensor<1x2xf64>, %[[LF_G:.+]]: tensor<1x2xf64>, %[[LF_U:.+]]: tensor, %[[LF_RNG:.+]]: tensor<2xui64>): -// CHECK: %[[DIR:.+]] = enzyme.select {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: %[[DIR:.+]] = impulse.select {{.*}} : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[DIR_BC:.+]] = "enzyme.broadcast"(%[[DIR]]) <{shape = array}> : (tensor) -> tensor<1x2xf64> // CHECK-NEXT: %[[HALF_STEP:.+]] = arith.mulf %[[DIR]], %[[HALF]] : tensor // CHECK-NEXT: %[[HALF_STEP_BC:.+]] = "enzyme.broadcast"(%[[HALF_STEP]]) <{shape = array}> : (tensor) -> tensor<1x2xf64> // CHECK-NEXT: %[[GRAD_SCALED:.+]] = arith.mulf %[[HALF_STEP_BC]], %[[LF_G]] : 
tensor<1x2xf64> // CHECK-NEXT: %[[P_HALF:.+]] = arith.subf %[[LF_P]], %[[GRAD_SCALED]] : tensor<1x2xf64> -// CHECK-NEXT: %[[INV_2D_LF:.+]] = enzyme.reshape %[[INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK-NEXT: %[[INV_2D_LF:.+]] = impulse.reshape %[[INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: %[[V_LF:.+]] = arith.mulf %[[INV_2D_LF]], %[[P_HALF]] : tensor<1x2xf64> // CHECK-NEXT: %[[DELTA_Q:.+]] = arith.mulf %[[DIR_BC]], %[[V_LF]] : tensor<1x2xf64> // CHECK-NEXT: %[[Q_NEW:.+]] = arith.addf %[[LF_Q]], %[[DELTA_Q]] : tensor<1x2xf64> @@ -71,38 +71,38 @@ module { // CHECK: } // CHECK: arith.mulf {{.*}} : tensor<1x2xf64> // CHECK-NEXT: arith.subf {{.*}} : tensor<1x2xf64> -// CHECK: enzyme.yield {{.*}} : tensor<1x2xf64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor, tensor<2xui64> +// CHECK: impulse.yield {{.*}} : tensor<1x2xf64>, tensor<1x2xf64>, tensor<1x2xf64>, tensor, tensor<2xui64> // CHECK-NEXT: } -// CHECK-NEXT: %[[INV_2D_FINAL:.+]] = enzyme.reshape %[[INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK-NEXT: %[[INV_2D_FINAL:.+]] = impulse.reshape %[[INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: %[[V_FINAL:.+]] = arith.mulf %[[INV_2D_FINAL]], {{.*}} : tensor<1x2xf64> -// CHECK-NEXT: %[[KE_FINAL_DOT:.+]] = enzyme.dot {{.*}}, %[[V_FINAL]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor +// CHECK-NEXT: %[[KE_FINAL_DOT:.+]] = impulse.dot {{.*}}, %[[V_FINAL]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x2xf64>, tensor<1x2xf64>) -> tensor // CHECK-NEXT: %[[KE_FINAL:.+]] = arith.mulf %[[KE_FINAL_DOT]], %[[HALF]] : tensor // CHECK-NEXT: arith.addf {{.*}} : tensor // CHECK-NEXT: arith.subf {{.*}} : tensor // CHECK-NEXT: math.exp {{.*}} : tensor // CHECK-NEXT: arith.minimumf {{.*}} : tensor -// CHECK-NEXT: enzyme.random {{.*}} {rng_distribution = #enzyme} +// CHECK-NEXT: impulse.random {{.*}} {rng_distribution = #impulse} // CHECK-NEXT: arith.cmpf olt, {{.*}} : 
tensor -// CHECK-NEXT: enzyme.select {{.*}} : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) -// CHECK-NEXT: enzyme.select {{.*}} : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) -// CHECK-NEXT: enzyme.select {{.*}} : (tensor, tensor, tensor) +// CHECK-NEXT: impulse.select {{.*}} : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) +// CHECK-NEXT: impulse.select {{.*}} : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) +// CHECK-NEXT: impulse.select {{.*}} : (tensor, tensor, tensor) // CHECK-LABEL: func.func @nuts_diag_mass // CHECK-SAME: (%{{.+}}: tensor<2xui64>, %{{.+}}: tensor, %{{.+}}: tensor) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) // CHECK-DAG: %[[N_INV_MASS:.+]] = arith.constant dense<[2.000000e+00, 3.000000e+00]> : tensor<2xf64> // CHECK-DAG: %[[N_MASS_SQRT:.+]] = arith.constant dense<[0.70710678118654746, 0.57735026918962584]> : tensor<2xf64> -// CHECK: enzyme.for_loop +// CHECK: impulse.for // CHECK: ^bb0( -// CHECK: enzyme.random {{.*}} {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x2xf64>) -// CHECK-NEXT: %[[N_SQRT_2D:.+]] = enzyme.reshape %[[N_MASS_SQRT]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK: impulse.random {{.*}} {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x2xf64>) +// CHECK-NEXT: %[[N_SQRT_2D:.+]] = impulse.reshape %[[N_MASS_SQRT]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: %[[N_P:.+]] = arith.mulf %[[N_SQRT_2D]], {{.*}} : tensor<1x2xf64> -// CHECK-NEXT: %[[N_INV_2D:.+]] = enzyme.reshape %[[N_INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK-NEXT: %[[N_INV_2D:.+]] = impulse.reshape %[[N_INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: arith.mulf %[[N_INV_2D]], %[[N_P]] : tensor<1x2xf64> -// CHECK-NEXT: enzyme.dot {{.*}} lhs_contracting_dimensions = array -// CHECK: enzyme.while_loop -// CHECK: enzyme.while_loop -// CHECK: enzyme.reshape %[[N_INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK-NEXT: impulse.dot {{.*}} 
lhs_contracting_dimensions = array +// CHECK: impulse.while +// CHECK: impulse.while +// CHECK: impulse.reshape %[[N_INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: arith.mulf {{.*}} : tensor<1x2xf64> -// CHECK: enzyme.reshape %[[N_INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> +// CHECK: impulse.reshape %[[N_INV_MASS]] : (tensor<2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: arith.mulf {{.*}} : tensor<1x2xf64> -// CHECK-NEXT: enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK-NEXT: impulse.dot {{.*}} lhs_contracting_dimensions = array diff --git a/enzyme/test/MLIR/ProbProg/hmc_kernel.mlir b/enzyme/test/MLIR/ProbProg/hmc_kernel.mlir index d0aa862c92e7..44a442be9d13 100644 --- a/enzyme/test/MLIR/ProbProg/hmc_kernel.mlir +++ b/enzyme/test/MLIR/ProbProg/hmc_kernel.mlir @@ -5,17 +5,17 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @test(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %s#0, %s#1 : tensor<2xui64>, tensor } func.func @hmc(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { hmc_config = #enzyme.hmc_config, - name = "hmc", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 0, num_samples = 10 } + { hmc_config = 
#impulse.hmc_config, + name = "hmc", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 0, num_samples = 10 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64> } @@ -37,55 +37,55 @@ module { // CHECK-DAG: %[[INIT_TRACE:.+]] = arith.constant dense<0.000000e+00> : tensor<1x1xf64> // // --- RNG splits --- -// CHECK: %[[SPLIT1:.+]]:2 = enzyme.randomSplit %[[RNG]] : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) -// CHECK-NEXT: %[[SPLIT2:.+]]:3 = enzyme.randomSplit %[[SPLIT1]]#0 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>, tensor<2xui64>) +// CHECK: %[[SPLIT1:.+]]:2 = impulse.randomSplit %[[RNG]] : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) +// CHECK-NEXT: %[[SPLIT2:.+]]:3 = impulse.randomSplit %[[SPLIT1]]#0 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>, tensor<2xui64>) // // --- Extract initial position from trace tensor --- -// CHECK-NEXT: %[[Q0_SLICE:.+]] = enzyme.dynamic_slice %[[INIT_TRACE]], %[[C0]], %[[C0]] {slice_sizes = array} : (tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> -// CHECK-NEXT: %[[Q0:.+]] = enzyme.dynamic_update_slice %[[INIT_TRACE]], %[[Q0_SLICE]], %[[C0]], %[[C0]] : (tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[Q0_SLICE:.+]] = impulse.dynamic_slice %[[INIT_TRACE]], %[[C0]], %[[C0]] {slice_sizes = array} : (tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[Q0:.+]] = impulse.dynamic_update_slice %[[INIT_TRACE]], %[[Q0_SLICE]], %[[C0]], %[[C0]] : (tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> // // --- Constrain position and call generate for U0 --- -// CHECK-NEXT: %[[Q0_CONS_S:.+]] = enzyme.dynamic_slice %[[Q0]], %[[C0]], %[[C0]] {slice_sizes = array} -// CHECK-NEXT: 
%[[Q0_CONS:.+]] = enzyme.dynamic_update_slice %[[INIT_TRACE]], %[[Q0_CONS_S]], %[[C0]], %[[C0]] +// CHECK-NEXT: %[[Q0_CONS_S:.+]] = impulse.dynamic_slice %[[Q0]], %[[C0]], %[[C0]] {slice_sizes = array} +// CHECK-NEXT: %[[Q0_CONS:.+]] = impulse.dynamic_update_slice %[[INIT_TRACE]], %[[Q0_CONS_S]], %[[C0]], %[[C0]] // CHECK-NEXT: %[[GEN_INIT:.+]]:4 = call @test.generate{{.*}}(%[[Q0_CONS]], %[[SPLIT2]]#1, %[[MEAN]], %[[STDDEV]]) : (tensor<1x1xf64>, tensor<2xui64>, tensor, tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) // CHECK-NEXT: %[[U0:.+]] = arith.negf %[[GEN_INIT]]#1 : tensor // // --- Initial gradient via autodiff --- // CHECK-NEXT: %[[AD_INIT:.+]]:3 = enzyme.autodiff_region(%[[Q0]], %[[ONE]]) { // CHECK-NEXT: ^bb0(%[[AD_ARG:.+]]: tensor<1x1xf64>): -// CHECK-NEXT: %[[AD_SLICE:.+]] = enzyme.dynamic_slice %[[AD_ARG]], %[[C0]], %[[C0]] {slice_sizes = array} -// CHECK-NEXT: %[[AD_CONS:.+]] = enzyme.dynamic_update_slice %[[INIT_TRACE]], %[[AD_SLICE]], %[[C0]], %[[C0]] +// CHECK-NEXT: %[[AD_SLICE:.+]] = impulse.dynamic_slice %[[AD_ARG]], %[[C0]], %[[C0]] {slice_sizes = array} +// CHECK-NEXT: %[[AD_CONS:.+]] = impulse.dynamic_update_slice %[[INIT_TRACE]], %[[AD_SLICE]], %[[C0]], %[[C0]] // CHECK-NEXT: %[[AD_GEN:.+]]:4 = func.call @test.generate{{.*}}(%[[AD_CONS]], %[[SPLIT2]]#1, %[[MEAN]], %[[STDDEV]]) : (tensor<1x1xf64>, tensor<2xui64>, tensor, tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) // CHECK-NEXT: %[[AD_NEG:.+]] = arith.negf %[[AD_GEN]]#1 : tensor // CHECK-NEXT: enzyme.yield %[[AD_NEG]], %[[AD_GEN]]#2 : tensor, tensor<2xui64> // CHECK-NEXT: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} // // --- Main sampling loop --- -// CHECK: %[[LOOP:.+]]:6 = enzyme.for_loop(%[[C0]] : tensor) to(%[[C10]] : tensor) step(%[[C1]] : tensor) iter_args(%[[Q0]], %[[AD_INIT]]#2, %[[U0]], %[[SPLIT2]]#0, %[[SAMPLES_INIT]], %[[ACC_INIT]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1>) -> 
tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1> { +// CHECK: %[[LOOP:.+]]:6 = impulse.for(%[[C0]] : tensor) to(%[[C10]] : tensor) step(%[[C1]] : tensor) iter_args(%[[Q0]], %[[AD_INIT]]#2, %[[U0]], %[[SPLIT2]]#0, %[[SAMPLES_INIT]], %[[ACC_INIT]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1> { // CHECK-NEXT: ^bb0(%[[ITER:.+]]: tensor, %[[Q:.+]]: tensor<1x1xf64>, %[[GRAD:.+]]: tensor<1x1xf64>, %[[U:.+]]: tensor, %[[RNG_I:.+]]: tensor<2xui64>, %[[SAMP_BUF:.+]]: tensor<10x1xf64>, %[[ACC_BUF:.+]]: tensor<10xi1>): // // --- Sample momentum p ~ N(0, I) --- -// CHECK-NEXT: %[[RNG_S:.+]]:3 = enzyme.randomSplit %[[RNG_I]] : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>, tensor<2xui64>) -// CHECK-NEXT: %[[RNG_M:.+]]:2 = enzyme.randomSplit %[[RNG_S]]#1 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) -// CHECK-NEXT: %[[RNG_P:.+]], %[[P:.+]] = enzyme.random %[[RNG_M]]#0, %[[ZERO_F]], %[[ONE]] {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x1xf64>) +// CHECK-NEXT: %[[RNG_S:.+]]:3 = impulse.randomSplit %[[RNG_I]] : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>, tensor<2xui64>) +// CHECK-NEXT: %[[RNG_M:.+]]:2 = impulse.randomSplit %[[RNG_S]]#1 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) +// CHECK-NEXT: %[[RNG_P:.+]], %[[P:.+]] = impulse.random %[[RNG_M]]#0, %[[ZERO_F]], %[[ONE]] {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x1xf64>) // // --- Transform momentum by mass matrix sqrt: p_transformed = massMatrixSqrt @ p --- -// CHECK-NEXT: %[[P_XFORM:.+]] = enzyme.dot %[[P]], {{.+}} {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[P_XFORM:.+]] = impulse.dot %[[P]], {{.+}} {{{.*}}lhs_contracting_dimensions = 
array{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> // // --- Initial kinetic energy K0 = 0.5 * p_transformed^T * M^-1 * p_transformed --- -// CHECK-NEXT: %[[P_V:.+]] = enzyme.dot %[[P_XFORM]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> -// CHECK-NEXT: %[[KE0_DOT:.+]] = enzyme.dot %[[P_XFORM]], %[[P_V]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor +// CHECK-NEXT: %[[P_V:.+]] = impulse.dot %[[P_XFORM]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[KE0_DOT:.+]] = impulse.dot %[[P_XFORM]], %[[P_V]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor // CHECK-NEXT: %[[KE0:.+]] = arith.mulf %[[KE0_DOT]], %[[HALF]] : tensor // // --- Initial Hamiltonian H0 = U + K --- // CHECK-NEXT: %[[H0:.+]] = arith.addf %[[U]], %[[KE0]] : tensor // // --- Leapfrog integration loop --- -// CHECK-NEXT: %[[LF:.+]]:5 = enzyme.for_loop(%[[C0]] : tensor) to(%[[C10]] : tensor) step(%[[C1]] : tensor) iter_args(%[[Q]], %[[P_XFORM]], %[[GRAD]], %[[U]], %[[RNG_S]]#2 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64> { +// CHECK-NEXT: %[[LF:.+]]:5 = impulse.for(%[[C0]] : tensor) to(%[[C10]] : tensor) step(%[[C1]] : tensor) iter_args(%[[Q]], %[[P_XFORM]], %[[GRAD]], %[[U]], %[[RNG_S]]#2 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64> { // CHECK-NEXT: ^bb0(%[[LF_I:.+]]: tensor, %[[LF_Q:.+]]: tensor<1x1xf64>, %[[LF_P:.+]]: tensor<1x1xf64>, %[[LF_G:.+]]: tensor<1x1xf64>, %[[LF_U:.+]]: tensor, %[[LF_RNG:.+]]: tensor<2xui64>): // // --- Leapfrog: direction selection --- -// CHECK-NEXT: %[[DIR:.+]] = enzyme.select %[[TRUE]], %[[EPS]], %[[NEG_EPS]] : (tensor, tensor, tensor) -> tensor +// CHECK-NEXT: %[[DIR:.+]] = 
impulse.select %[[TRUE]], %[[EPS]], %[[NEG_EPS]] : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[DIR_BC:.+]] = "enzyme.broadcast"(%[[DIR]]) <{shape = array}> : (tensor) -> tensor<1x1xf64> // // --- Leapfrog: half step momentum p_half = p - (eps/2) * grad --- @@ -95,15 +95,15 @@ module { // CHECK-NEXT: %[[P_HALF:.+]] = arith.subf %[[LF_P]], %[[GRAD_SCALED]] : tensor<1x1xf64> // // --- Leapfrog: full step position q_new = q + eps * M^-1 * p_half --- -// CHECK-NEXT: %[[P_VINV:.+]] = enzyme.dot %[[P_HALF]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[P_VINV:.+]] = impulse.dot %[[P_HALF]], {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> // CHECK-NEXT: %[[P_STEP:.+]] = arith.mulf %[[DIR_BC]], %[[P_VINV]] : tensor<1x1xf64> // CHECK-NEXT: %[[Q_NEW:.+]] = arith.addf %[[LF_Q]], %[[P_STEP]] : tensor<1x1xf64> // // --- Leapfrog: gradient at new position --- // CHECK-NEXT: %[[AD_LF:.+]]:3 = enzyme.autodiff_region(%[[Q_NEW]], %[[ONE]]) { // CHECK-NEXT: ^bb0(%[[AD_LF_ARG:.+]]: tensor<1x1xf64>): -// CHECK-NEXT: %[[AD_LF_SLICE:.+]] = enzyme.dynamic_slice %[[AD_LF_ARG]], %[[C0]], %[[C0]] {slice_sizes = array} -// CHECK-NEXT: %[[AD_LF_CONS:.+]] = enzyme.dynamic_update_slice %[[INIT_TRACE]], %[[AD_LF_SLICE]], %[[C0]], %[[C0]] +// CHECK-NEXT: %[[AD_LF_SLICE:.+]] = impulse.dynamic_slice %[[AD_LF_ARG]], %[[C0]], %[[C0]] {slice_sizes = array} +// CHECK-NEXT: %[[AD_LF_CONS:.+]] = impulse.dynamic_update_slice %[[INIT_TRACE]], %[[AD_LF_SLICE]], %[[C0]], %[[C0]] // CHECK-NEXT: %[[AD_LF_GEN:.+]]:4 = func.call @test.generate(%[[AD_LF_CONS]], %[[LF_RNG]], %[[MEAN]], %[[STDDEV]]) : (tensor<1x1xf64>, tensor<2xui64>, tensor, tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) // CHECK-NEXT: %[[AD_LF_NEG:.+]] = arith.negf %[[AD_LF_GEN]]#1 : tensor // CHECK-NEXT: enzyme.yield %[[AD_LF_NEG]], %[[AD_LF_GEN]]#2 : tensor, tensor<2xui64> @@ -114,12 +114,12 @@ module { // CHECK-NEXT: %[[P_NEW:.+]] = arith.subf 
%[[P_HALF]], %[[GRAD_NEW_SCALED]] : tensor<1x1xf64> // // --- Leapfrog yield --- -// CHECK-NEXT: enzyme.yield %[[Q_NEW]], %[[P_NEW]], %[[AD_LF]]#2, %[[AD_LF]]#0, %[[AD_LF]]#1 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64> +// CHECK-NEXT: impulse.yield %[[Q_NEW]], %[[P_NEW]], %[[AD_LF]]#2, %[[AD_LF]]#0, %[[AD_LF]]#1 : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64> // CHECK-NEXT: } // // --- Final kinetic energy K_new = 0.5 * p_new^T * M^-1 * p_new --- -// CHECK-NEXT: %[[LF_P_V:.+]] = enzyme.dot %[[LF]]#1, {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> -// CHECK-NEXT: %[[KE_DOT:.+]] = enzyme.dot %[[LF]]#1, %[[LF_P_V]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor +// CHECK-NEXT: %[[LF_P_V:.+]] = impulse.dot %[[LF]]#1, {{.+}} {{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[KE_DOT:.+]] = impulse.dot %[[LF]]#1, %[[LF_P_V]] {{{.*}}lhs_contracting_dimensions = array{{.*}}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor // CHECK-NEXT: %[[KE:.+]] = arith.mulf %[[KE_DOT]], %[[HALF]] : tensor // // --- Final Hamiltonian H_new = U_new + K_new --- @@ -131,26 +131,26 @@ module { // CHECK-NEXT: %[[ACCEPT_PROB:.+]] = arith.minimumf %[[EXP_DH]], %[[ONE]] : tensor // // --- Draw uniform for MH --- -// CHECK-NEXT: %[[RNG_U:.+]], %[[UNIF:.+]] = enzyme.random %[[LF]]#4, %[[ZERO_F]], %[[ONE]] {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %[[RNG_U:.+]], %[[UNIF:.+]] = impulse.random %[[LF]]#4, %[[ZERO_F]], %[[ONE]] {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) // // --- Accept comparison --- // CHECK-NEXT: %[[ACCEPTED:.+]] = arith.cmpf olt, %[[UNIF]], %[[ACCEPT_PROB]] : tensor // // --- Select q, grad, U based on acceptance --- -// CHECK-NEXT: %[[Q_SEL:.+]] = enzyme.select %[[ACCEPTED]], %[[LF]]#0, %[[Q]] : (tensor, 
tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> -// CHECK-NEXT: %[[GRAD_SEL:.+]] = enzyme.select %[[ACCEPTED]], %[[LF]]#2, %[[GRAD]] : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> -// CHECK-NEXT: %[[U_SEL:.+]] = enzyme.select %[[ACCEPTED]], %[[LF]]#3, %[[U]] : (tensor, tensor, tensor) -> tensor +// CHECK-NEXT: %[[Q_SEL:.+]] = impulse.select %[[ACCEPTED]], %[[LF]]#0, %[[Q]] : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[GRAD_SEL:.+]] = impulse.select %[[ACCEPTED]], %[[LF]]#2, %[[GRAD]] : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[U_SEL:.+]] = impulse.select %[[ACCEPTED]], %[[LF]]#3, %[[U]] : (tensor, tensor, tensor) -> tensor // // --- Store samples: conditional on iteration index --- // CHECK-NEXT: %[[STORE_COND:.+]] = arith.cmpi sge, %[[ITER]], %[[C0]] : tensor -// CHECK-NEXT: %[[SAMP_UPD:.+]] = enzyme.dynamic_update_slice %[[SAMP_BUF]], %[[Q_SEL]], %[[ITER]], %[[C0]] : (tensor<10x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<10x1xf64> -// CHECK-NEXT: %[[SAMP_SEL:.+]] = enzyme.select %[[STORE_COND]], %[[SAMP_UPD]], %[[SAMP_BUF]] : (tensor, tensor<10x1xf64>, tensor<10x1xf64>) -> tensor<10x1xf64> -// CHECK-NEXT: %[[ACC_1D:.+]] = enzyme.reshape %[[ACCEPTED]] : (tensor) -> tensor<1xi1> -// CHECK-NEXT: %[[ACC_UPD:.+]] = enzyme.dynamic_update_slice %[[ACC_BUF]], %[[ACC_1D]], %[[ITER]] : (tensor<10xi1>, tensor<1xi1>, tensor) -> tensor<10xi1> -// CHECK-NEXT: %[[ACC_SEL:.+]] = enzyme.select %[[STORE_COND]], %[[ACC_UPD]], %[[ACC_BUF]] : (tensor, tensor<10xi1>, tensor<10xi1>) -> tensor<10xi1> +// CHECK-NEXT: %[[SAMP_UPD:.+]] = impulse.dynamic_update_slice %[[SAMP_BUF]], %[[Q_SEL]], %[[ITER]], %[[C0]] : (tensor<10x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<10x1xf64> +// CHECK-NEXT: %[[SAMP_SEL:.+]] = impulse.select %[[STORE_COND]], %[[SAMP_UPD]], %[[SAMP_BUF]] : (tensor, tensor<10x1xf64>, tensor<10x1xf64>) -> tensor<10x1xf64> +// CHECK-NEXT: %[[ACC_1D:.+]] = 
impulse.reshape %[[ACCEPTED]] : (tensor) -> tensor<1xi1> +// CHECK-NEXT: %[[ACC_UPD:.+]] = impulse.dynamic_update_slice %[[ACC_BUF]], %[[ACC_1D]], %[[ITER]] : (tensor<10xi1>, tensor<1xi1>, tensor) -> tensor<10xi1> +// CHECK-NEXT: %[[ACC_SEL:.+]] = impulse.select %[[STORE_COND]], %[[ACC_UPD]], %[[ACC_BUF]] : (tensor, tensor<10xi1>, tensor<10xi1>) -> tensor<10xi1> // // --- Yield from sampling loop --- -// CHECK-NEXT: enzyme.yield %[[Q_SEL]], %[[GRAD_SEL]], %[[U_SEL]], %[[RNG_S]]#0, %[[SAMP_SEL]], %[[ACC_SEL]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1> +// CHECK-NEXT: impulse.yield %[[Q_SEL]], %[[GRAD_SEL]], %[[U_SEL]], %[[RNG_S]]#0, %[[SAMP_SEL]], %[[ACC_SEL]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1> // CHECK-NEXT: } // CHECK-NEXT: return %[[LOOP]]#4, %[[LOOP]]#5, %[[LOOP]]#3 : tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64> // CHECK-NEXT: } @@ -161,10 +161,10 @@ module { // CHECK-DAG: %[[G_C0:.+]] = arith.constant dense<0> : tensor // CHECK-DAG: %[[G_ZERO:.+]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[G_TRACE_INIT:.+]] = arith.constant dense<0.000000e+00> : tensor<1x1xf64> -// CHECK: %[[G_SLICED:.+]] = enzyme.slice %[[G_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x1xf64>) -> tensor<1x1xf64> -// CHECK-NEXT: %[[G_VAL:.+]] = enzyme.reshape %[[G_SLICED]] : (tensor<1x1xf64>) -> tensor +// CHECK: %[[G_SLICED:.+]] = impulse.slice %[[G_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[G_VAL:.+]] = impulse.reshape %[[G_SLICED]] : (tensor<1x1xf64>) -> tensor // CHECK-NEXT: %[[G_LP:.+]] = call @logpdf(%[[G_VAL]], %[[G_ARG2]], %[[G_ARG3]]) : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[G_W:.+]] = arith.addf %[[G_LP]], %[[G_ZERO]] : tensor -// CHECK-NEXT: %[[G_RS:.+]] = enzyme.reshape %[[G_VAL]] : (tensor) -> tensor<1x1xf64> -// 
CHECK-NEXT: %[[G_TR:.+]] = enzyme.dynamic_update_slice %[[G_TRACE_INIT]], %[[G_RS]], %[[G_C0]], %[[G_C0]] : (tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[G_RS:.+]] = impulse.reshape %[[G_VAL]] : (tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[G_TR:.+]] = impulse.dynamic_update_slice %[[G_TRACE_INIT]], %[[G_RS]], %[[G_C0]], %[[G_C0]] : (tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> // CHECK-NEXT: return %[[G_TR]], %[[G_W]], %[[G_ARG1]], %[[G_VAL]] : tensor<1x1xf64>, tensor, tensor<2xui64>, tensor diff --git a/enzyme/test/MLIR/ProbProg/mcmc_custom_logpdf.mlir b/enzyme/test/MLIR/ProbProg/mcmc_custom_logpdf.mlir index f2a6017f07c2..3d744873cb93 100644 --- a/enzyme/test/MLIR/ProbProg/mcmc_custom_logpdf.mlir +++ b/enzyme/test/MLIR/ProbProg/mcmc_custom_logpdf.mlir @@ -2,7 +2,7 @@ module { func.func @logpdf(%x : tensor<2xf64>) -> tensor { - %sum_sq = enzyme.dot %x, %x {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor + %sum_sq = impulse.dot %x, %x {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor %neg_half = arith.constant dense<-5.000000e-01> : tensor %result = arith.mulf %neg_half, %sum_sq : tensor return %result : tensor @@ -15,7 +15,7 @@ module { // CHECK: func.call @logpdf // CHECK-NEXT: %[[NEG:.+]] = arith.negf // CHECK-NEXT: enzyme.yield - // CHECK: enzyme.for_loop + // CHECK: impulse.for // CHECK: enzyme.autodiff_region // CHECK: func.call @logpdf // CHECK-NEXT: %{{.+}} = arith.negf @@ -23,9 +23,9 @@ module { func.func @nuts_logpdf(%rng : tensor<2xui64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[1.0, -1.0]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = 
"enzyme.mcmc"(%rng, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %step_size, %init_pos) { logpdf_fn = @logpdf, - nuts_config = #enzyme.nuts_config, + nuts_config = #impulse.nuts_config, name = "nuts_logpdf", selection = [], all_addresses = [], @@ -44,7 +44,7 @@ module { // CHECK: func.call @logpdf // CHECK-NEXT: %{{.+}} = arith.negf // CHECK-NEXT: enzyme.yield - // CHECK: enzyme.for_loop + // CHECK: impulse.for // CHECK: enzyme.autodiff_region // CHECK: func.call @logpdf // CHECK-NEXT: %{{.+}} = arith.negf @@ -52,9 +52,9 @@ module { func.func @hmc_logpdf(%rng : tensor<2xui64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %step_size, %init_pos) { logpdf_fn = @logpdf, - hmc_config = #enzyme.hmc_config, + hmc_config = #impulse.hmc_config, name = "hmc_logpdf", selection = [], all_addresses = [], @@ -67,7 +67,7 @@ module { func.func @shifted_logpdf(%x : tensor<2xf64>, %mu : tensor<2xf64>) -> tensor { %diff = arith.subf %x, %mu : tensor<2xf64> - %sum_sq = enzyme.dot %diff, %diff {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor + %sum_sq = impulse.dot %diff, %diff {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor %neg_half = arith.constant dense<-5.000000e-01> : tensor %result = arith.mulf %neg_half, %sum_sq : tensor return %result : tensor @@ -80,12 +80,17 @@ module { // CHECK: func.call @shifted_logpdf // CHECK-NEXT: %[[NEG:.+]] = arith.negf // CHECK-NEXT: enzyme.yield + // CHECK: impulse.for + // CHECK: enzyme.autodiff_region + // CHECK: func.call @shifted_logpdf + // CHECK-NEXT: %{{.+}} = 
arith.negf + // CHECK-NEXT: enzyme.yield func.func @nuts_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %mu, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %mu, %step_size, %init_pos) { logpdf_fn = @shifted_logpdf, - nuts_config = #enzyme.nuts_config, + nuts_config = #impulse.nuts_config, name = "nuts_shifted_logpdf", selection = [], all_addresses = [], @@ -103,12 +108,17 @@ module { // CHECK: func.call @shifted_logpdf // CHECK-NEXT: %{{.+}} = arith.negf // CHECK-NEXT: enzyme.yield + // CHECK: impulse.for + // CHECK: enzyme.autodiff_region + // CHECK: func.call @shifted_logpdf + // CHECK-NEXT: %{{.+}} = arith.negf + // CHECK-NEXT: enzyme.yield func.func @hmc_shifted_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %mu, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %mu, %step_size, %init_pos) { logpdf_fn = @shifted_logpdf, - hmc_config = #enzyme.hmc_config, + hmc_config = #impulse.hmc_config, name = "hmc_shifted_logpdf", selection = [], all_addresses = [], @@ -124,7 +134,7 @@ module { %diff_sq = arith.mulf %diff, %diff : tensor<2xf64> %weighted = arith.mulf %precision, %diff_sq : tensor<2xf64> %ones = arith.constant dense<1.0> : tensor<2xf64> - %sum = enzyme.dot %ones, %weighted {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor + %sum = impulse.dot %ones, %weighted {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_contracting_dimensions = array} : 
(tensor<2xf64>, tensor<2xf64>) -> tensor %neg_half = arith.constant dense<-5.000000e-01> : tensor %result = arith.mulf %neg_half, %sum : tensor return %result : tensor @@ -137,12 +147,17 @@ module { // CHECK: func.call @anisotropic_logpdf // CHECK-NEXT: %[[NEG:.+]] = arith.negf // CHECK-NEXT: enzyme.yield + // CHECK: impulse.for + // CHECK: enzyme.autodiff_region + // CHECK: func.call @anisotropic_logpdf + // CHECK-NEXT: %{{.+}} = arith.negf + // CHECK-NEXT: enzyme.yield func.func @nuts_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>, %precision : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %mu, %precision, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %mu, %precision, %step_size, %init_pos) { logpdf_fn = @anisotropic_logpdf, - nuts_config = #enzyme.nuts_config, + nuts_config = #impulse.nuts_config, name = "nuts_anisotropic_logpdf", selection = [], all_addresses = [], @@ -160,12 +175,17 @@ module { // CHECK: func.call @anisotropic_logpdf // CHECK-NEXT: %{{.+}} = arith.negf // CHECK-NEXT: enzyme.yield + // CHECK: impulse.for + // CHECK: enzyme.autodiff_region + // CHECK: func.call @anisotropic_logpdf + // CHECK-NEXT: %{{.+}} = arith.negf + // CHECK-NEXT: enzyme.yield func.func @hmc_anisotropic_logpdf(%rng : tensor<2xui64>, %mu : tensor<2xf64>, %precision : tensor<2xf64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %mu, %precision, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %mu, %precision, %step_size, %init_pos) { logpdf_fn = @anisotropic_logpdf, - hmc_config = #enzyme.hmc_config, + hmc_config = #impulse.hmc_config, name = "hmc_anisotropic_logpdf", selection = [], all_addresses = [], diff --git 
a/enzyme/test/MLIR/ProbProg/mcmc_sampling.mlir b/enzyme/test/MLIR/ProbProg/mcmc_sampling.mlir index ce15573d3bb8..18bc6a7eeacb 100644 --- a/enzyme/test/MLIR/ProbProg/mcmc_sampling.mlir +++ b/enzyme/test/MLIR/ProbProg/mcmc_sampling.mlir @@ -5,17 +5,17 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @test(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %s#0, %s#1 : tensor<2xui64>, tensor } func.func @sampling_basic(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "sampling_basic", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 0, num_samples = 10 } + { nuts_config = #impulse.nuts_config, + name = "sampling_basic", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 0, num_samples = 10 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64> } @@ -23,10 +23,10 @@ module { func.func @sampling_thinning(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> 
(tensor<5x1xf64>, tensor<5xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "sampling_thinning", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 0, num_samples = 10, thinning = 2 } + { nuts_config = #impulse.nuts_config, + name = "sampling_thinning", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 0, num_samples = 10, thinning = 2 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<5x1xf64>, tensor<5xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<5x1xf64>, tensor<5xi1>, tensor<2xui64> } @@ -34,10 +34,10 @@ module { func.func @sampling_with_warmup(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "sampling_with_warmup", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 5, num_samples = 10 } + { nuts_config = #impulse.nuts_config, + name = "sampling_with_warmup", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 5, num_samples = 10 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : 
tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64> } @@ -55,8 +55,8 @@ module { // CHECK-DAG: %[[INIT_TRACE:.+]] = arith.constant dense<0.000000e+00> : tensor<1x1xf64> // // --- RNG splits --- -// CHECK: enzyme.randomSplit -// CHECK: enzyme.randomSplit +// CHECK: impulse.randomSplit +// CHECK: impulse.randomSplit // // --- Initial gradient via autodiff --- // CHECK: enzyme.autodiff_region(%{{.+}}, %{{.+}}) { @@ -67,23 +67,23 @@ module { // CHECK: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} // // --- Sampling loop: for i in 0..10 --- -// CHECK: %[[SLOOP:.+]]:6 = enzyme.for_loop(%[[C0]] : tensor) to(%[[C10]] : tensor) +// CHECK: %[[SLOOP:.+]]:6 = impulse.for(%[[C0]] : tensor) to(%[[C10]] : tensor) // CHECK-SAME: iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[SAMP_INIT]], %[[ACC_INIT]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1>) // CHECK-SAME: -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1> // CHECK: ^bb0(%[[S_ITER:.+]]: tensor, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor<2xui64>, %{{.+}}: tensor<10x1xf64>, %{{.+}}: tensor<10xi1>): // // --- Momentum sampling --- -// CHECK: enzyme.random {{.*}} {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x1xf64>) +// CHECK: impulse.random {{.*}} {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x1xf64>) // // --- Kinetic energy --- -// CHECK: enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK: impulse.dot {{.*}} lhs_contracting_dimensions = array // // --- NUTS tree building --- -// CHECK: enzyme.while_loop +// CHECK: impulse.while // // --- Sample storage: update samples buffer --- -// CHECK: enzyme.dynamic_update_slice %{{.+}}, %{{.+}}, %[[S_ITER]], %{{.+}} : (tensor<10x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<10x1xf64> -// CHECK: enzyme.yield %{{.+}}, %{{.+}}, 
%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1> +// CHECK: impulse.dynamic_update_slice %{{.+}}, %{{.+}}, %[[S_ITER]], %{{.+}} : (tensor<10x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<10x1xf64> +// CHECK: impulse.yield %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<10x1xf64>, tensor<10xi1> // CHECK: } // CHECK: return %[[SLOOP]]#4, %[[SLOOP]]#5, %[[SLOOP]]#3 : tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64> // CHECK: } @@ -95,7 +95,7 @@ module { // CHECK-SAME: -> (tensor<5x1xf64>, tensor<5xi1>, tensor<2xui64>) // // --- Thinning loop --- -// CHECK: enzyme.for_loop +// CHECK: impulse.for // CHECK-SAME: iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<5x1xf64>, tensor<5xi1>) // CHECK-SAME: -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<5x1xf64>, tensor<5xi1> // @@ -107,7 +107,7 @@ module { // CHECK: arith.andi // // --- Sample storage --- -// CHECK: enzyme.dynamic_update_slice %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : (tensor<5x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<5x1xf64> +// CHECK: impulse.dynamic_update_slice %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : (tensor<5x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<5x1xf64> // CHECK: return {{.*}} : tensor<5x1xf64>, tensor<5xi1>, tensor<2xui64> // CHECK: } @@ -118,7 +118,7 @@ module { // CHECK-SAME: -> (tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64>) // // --- Warmup loop --- -// CHECK: %[[WARMUP:.+]]:16 = enzyme.for_loop +// CHECK: %[[WARMUP:.+]]:16 = impulse.for // CHECK-SAME: iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, 
tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor) // CHECK-SAME: -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor // @@ -130,27 +130,27 @@ module { // CHECK: math.exp // // --- Welford update --- -// CHECK: enzyme.reshape {{.*}} : (tensor<1x1xf64>) -> tensor<1xf64> +// CHECK: impulse.reshape {{.*}} : (tensor<1x1xf64>) -> tensor<1xf64> // CHECK: arith.subf {{.*}} : tensor<1xf64> // CHECK: arith.divf {{.*}} : tensor<1xf64> // CHECK: arith.addf {{.*}} : tensor<1xf64> // // --- Window boundary logic --- -// CHECK: enzyme.if +// CHECK: impulse.if // CHECK: math.sqrt -// CHECK: enzyme.yield +// CHECK: impulse.yield // CHECK: }, { -// CHECK: enzyme.yield +// CHECK: impulse.yield // CHECK: }) // // --- Post-warmup sampling --- -// CHECK: enzyme.for_loop +// CHECK: impulse.for // CHECK-SAME: iter_args(%[[WARMUP]]#0, %[[WARMUP]]#1, %[[WARMUP]]#2, %[[WARMUP]]#3 // CHECK-SAME: : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64> // // --- Momentum with adapted mass matrix --- -// CHECK: enzyme.dot {{.*}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> -// CHECK: enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK: impulse.dot {{.*}} : (tensor<1x1xf64>, tensor<1x1xf64>) -> tensor<1x1xf64> +// CHECK: impulse.dot {{.*}} lhs_contracting_dimensions = array // // CHECK: return {{.*}} : tensor<10x1xf64>, tensor<10xi1>, tensor<2xui64> // CHECK: } diff --git a/enzyme/test/MLIR/ProbProg/mcmc_strong_zero.mlir b/enzyme/test/MLIR/ProbProg/mcmc_strong_zero.mlir index d556904f5e0c..63537f818d74 100644 --- a/enzyme/test/MLIR/ProbProg/mcmc_strong_zero.mlir +++ b/enzyme/test/MLIR/ProbProg/mcmc_strong_zero.mlir @@ -2,7 +2,7 @@ module { func.func @logpdf(%x : tensor<2xf64>) -> tensor { - %sum_sq = enzyme.dot %x, %x {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, 
rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor + %sum_sq = impulse.dot %x, %x {lhs_batching_dimensions = array, rhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor %neg_half = arith.constant dense<-5.000000e-01> : tensor %result = arith.mulf %neg_half, %sum_sq : tensor return %result : tensor @@ -11,15 +11,15 @@ module { // CHECK-LABEL: func.func @nuts_strong_zero // CHECK: enzyme.autodiff_region // CHECK: strong_zero = true - // CHECK: enzyme.for_loop + // CHECK: impulse.for // CHECK: enzyme.autodiff_region // CHECK: strong_zero = true func.func @nuts_strong_zero(%rng : tensor<2xui64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[1.0, -1.0]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %step_size, %init_pos) { logpdf_fn = @logpdf, - nuts_config = #enzyme.nuts_config, + nuts_config = #impulse.nuts_config, autodiff_attrs = {strong_zero = true}, name = "nuts_strong_zero", selection = [], @@ -34,16 +34,16 @@ module { // CHECK-LABEL: func.func @nuts_default // CHECK: enzyme.autodiff_region // CHECK-NOT: strong_zero - // CHECK: enzyme.for_loop + // CHECK: impulse.for // CHECK: enzyme.autodiff_region // CHECK-NOT: strong_zero // CHECK: enzyme.yield func.func @nuts_default(%rng : tensor<2xui64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[1.0, -1.0]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %step_size, %init_pos) { logpdf_fn = @logpdf, - nuts_config = #enzyme.nuts_config, + nuts_config = #impulse.nuts_config, name = "nuts_default", selection = [], all_addresses = [], @@ -57,15 +57,15 @@ module { // CHECK-LABEL: func.func @hmc_strong_zero 
// CHECK: enzyme.autodiff_region // CHECK: strong_zero = true - // CHECK: enzyme.for_loop + // CHECK: impulse.for // CHECK: enzyme.autodiff_region // CHECK: strong_zero = true func.func @hmc_strong_zero(%rng : tensor<2xui64>) -> (tensor<1x2xf64>, tensor<1xi1>, tensor<2xui64>) { %init_pos = arith.constant dense<[[0.5, -0.5]]> : tensor<1x2xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = "enzyme.mcmc"(%rng, %step_size, %init_pos) { + %res:8 = "impulse.infer"(%rng, %step_size, %init_pos) { logpdf_fn = @logpdf, - hmc_config = #enzyme.hmc_config, + hmc_config = #impulse.hmc_config, autodiff_attrs = {strong_zero = true}, name = "hmc_strong_zero", selection = [], diff --git a/enzyme/test/MLIR/ProbProg/mcmc_warmup.mlir b/enzyme/test/MLIR/ProbProg/mcmc_warmup.mlir index 505c78b8fade..21262b725420 100644 --- a/enzyme/test/MLIR/ProbProg/mcmc_warmup.mlir +++ b/enzyme/test/MLIR/ProbProg/mcmc_warmup.mlir @@ -5,7 +5,7 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @test(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %s#0, %s#1 : tensor<2xui64>, tensor } @@ -13,10 +13,10 @@ module { func.func @warmup_both(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = 
#enzyme.nuts_config, - name = "warmup_both", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 10, num_samples = 1 } + { nuts_config = #impulse.nuts_config, + name = "warmup_both", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 10, num_samples = 1 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> } @@ -25,10 +25,10 @@ module { func.func @warmup_step_only(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "warmup_step_only", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 10, num_samples = 1 } + { nuts_config = #impulse.nuts_config, + name = "warmup_step_only", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 10, num_samples = 1 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> } @@ -37,10 +37,10 @@ module { func.func @warmup_mass_only(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = 
impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "warmup_mass_only", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 10, num_samples = 1 } + { nuts_config = #impulse.nuts_config, + name = "warmup_mass_only", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 10, num_samples = 1 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> } @@ -49,10 +49,10 @@ module { func.func @warmup_none(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "warmup_none", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 10, num_samples = 1 } + { nuts_config = #impulse.nuts_config, + name = "warmup_none", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 10, num_samples = 1 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> } @@ -92,14 +92,14 @@ module { // CHECK-DAG: %[[INIT_TRACE:.+]] = arith.constant dense<0.000000e+00> : tensor<1x1xf64> // // --- Init: RNG splits --- -// CHECK: %[[SPLIT1:.+]]:2 = enzyme.randomSplit %[[RNG]] -// CHECK-NEXT: %[[SPLIT2:.+]]:3 = enzyme.randomSplit 
%[[SPLIT1]]#0 +// CHECK: %[[SPLIT1:.+]]:2 = impulse.randomSplit %[[RNG]] +// CHECK-NEXT: %[[SPLIT2:.+]]:3 = impulse.randomSplit %[[SPLIT1]]#0 // // --- Extract initial position --- -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_slice %[[INIT_TRACE]], %[[C0]], %[[C0]] {slice_sizes = array} -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_update_slice %[[INIT_TRACE]], %{{.+}}, %[[C0]], %[[C0]] -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_slice %{{.+}}, %[[C0]], %[[C0]] {slice_sizes = array} -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_update_slice %[[INIT_TRACE]], %{{.+}}, %[[C0]], %[[C0]] +// CHECK-NEXT: %{{.+}} = impulse.dynamic_slice %[[INIT_TRACE]], %[[C0]], %[[C0]] {slice_sizes = array} +// CHECK-NEXT: %{{.+}} = impulse.dynamic_update_slice %[[INIT_TRACE]], %{{.+}}, %[[C0]], %[[C0]] +// CHECK-NEXT: %{{.+}} = impulse.dynamic_slice %{{.+}}, %[[C0]], %[[C0]] {slice_sizes = array} +// CHECK-NEXT: %{{.+}} = impulse.dynamic_update_slice %[[INIT_TRACE]], %{{.+}}, %[[C0]], %[[C0]] // // --- Initial potential U0 --- // CHECK-NEXT: %[[GEN0:.+]]:4 = call @test.generate{{.*}}(%{{.+}}, %[[SPLIT2]]#1, %[[MEAN]], %[[STDDEV]]) @@ -108,29 +108,29 @@ module { // --- Initial gradient via autodiff --- // CHECK-NEXT: %[[AD0:.+]]:3 = enzyme.autodiff_region(%{{.+}}, %[[ONE]]) { // CHECK-NEXT: ^bb0(%{{.+}}: tensor<1x1xf64>): -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_slice -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_update_slice +// CHECK-NEXT: %{{.+}} = impulse.dynamic_slice +// CHECK-NEXT: %{{.+}} = impulse.dynamic_update_slice // CHECK-NEXT: %{{.+}} = func.call @test.generate // CHECK-NEXT: %{{.+}} = arith.negf // CHECK-NEXT: enzyme.yield // CHECK-NEXT: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} // // --- Warmup loop: 16 iter_args --- -// CHECK-NEXT: %[[WARMUP:.+]]:16 = enzyme.for_loop(%[[C0]] : tensor) to(%[[C10]] : tensor) step(%[[C1]] : tensor) iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[ZERO_F]], %[[ZERO_F]], %[[ZERO_F]], %[[C0]], %[[EPS_INIT]], 
%[[ZERO_1D]], %[[ZERO_1D]], %[[C0]], %[[C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor { +// CHECK-NEXT: %[[WARMUP:.+]]:16 = impulse.for(%[[C0]] : tensor) to(%[[C10]] : tensor) step(%[[C1]] : tensor) iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[ZERO_F]], %[[ZERO_F]], %[[ZERO_F]], %[[C0]], %[[EPS_INIT]], %[[ZERO_1D]], %[[ZERO_1D]], %[[C0]], %[[C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor { // CHECK-NEXT: ^bb0(%[[WI:.+]]: tensor, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor<2xui64>, %{{.+}}: tensor, %[[INV_MASS:.+]]: tensor<1x1xf64>, %[[MASS_SQRT:.+]]: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor<1xf64>, %{{.+}}: tensor<1xf64>, %{{.+}}: tensor, %{{.+}}: tensor): // // --- Momentum sampling with mass matrix --- -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}}, %{{.+}} = enzyme.random %{{.+}}, %[[ZERO_F]], %[[ONE]] {rng_distribution = #enzyme} -// CHECK-NEXT: %{{.+}} = enzyme.dot %{{.+}}, %[[MASS_SQRT]] {{.*}} -// CHECK-NEXT: %{{.+}} = enzyme.dot %{{.+}}, %[[INV_MASS]] {{.*}} -// CHECK-NEXT: %{{.+}} = enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// 
CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}}, %{{.+}} = impulse.random %{{.+}}, %[[ZERO_F]], %[[ONE]] {rng_distribution = #impulse} +// CHECK-NEXT: %{{.+}} = impulse.dot %{{.+}}, %[[MASS_SQRT]] {{.*}} +// CHECK-NEXT: %{{.+}} = impulse.dot %{{.+}}, %[[INV_MASS]] {{.*}} +// CHECK-NEXT: %{{.+}} = impulse.dot {{.*}} lhs_contracting_dimensions = array // CHECK-NEXT: %{{.+}} = arith.mulf %{{.+}}, %[[HALF]] // CHECK-NEXT: %{{.+}} = arith.addf // // --- NUTS tree building (tested in nuts_kernel.mlir) --- -// CHECK-NEXT: %[[TREE:.+]]:18 = enzyme.while_loop +// CHECK-NEXT: %[[TREE:.+]]:18 = impulse.while // // --- Dual averaging: compute mean accept prob --- // CHECK: %{{.+}} = arith.maxsi %[[TREE]]#15, %[[C1]] @@ -167,7 +167,7 @@ module { // // --- Window boundary check: select averaged at boundary --- // CHECK-NEXT: %[[IS_END:.+]] = arith.cmpi eq, %[[WI]], %[[C9]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[IS_END]], %[[SS_AVG]], %[[SS_CUR]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[IS_END]], %[[SS_AVG]], %[[SS_CUR]] // // --- Clamp step size --- // CHECK-NEXT: %{{.+}} = arith.maximumf %{{.+}}, %[[FMIN]] @@ -179,7 +179,7 @@ module { // CHECK-NEXT: %[[IN_WINDOW:.+]] = arith.andi // // --- Welford update --- -// CHECK-NEXT: %{{.+}} = enzyme.reshape %[[TREE]]#6 : (tensor<1x1xf64>) -> tensor<1xf64> +// CHECK-NEXT: %{{.+}} = impulse.reshape %[[TREE]]#6 : (tensor<1x1xf64>) -> tensor<1xf64> // CHECK-NEXT: %{{.+}} = arith.addi // CHECK-NEXT: %{{.+}} = arith.sitofp // CHECK-NEXT: %{{.+}} = "enzyme.broadcast" @@ -191,20 +191,20 @@ module { // CHECK-NEXT: %{{.+}} = arith.addf // // --- Conditional Welford update --- -// CHECK-NEXT: %{{.+}} = enzyme.select %[[IN_WINDOW]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[IN_WINDOW]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[IN_WINDOW]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[IN_WINDOW]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[IN_WINDOW]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[IN_WINDOW]] // // --- Window 
schedule logic --- // CHECK-NEXT: %{{.+}} = arith.cmpi eq // CHECK-NEXT: %{{.+}} = arith.cmpi eq // CHECK-NEXT: %{{.+}} = arith.andi // CHECK-NEXT: %{{.+}} = arith.addi -// CHECK-NEXT: %{{.+}} = enzyme.select +// CHECK-NEXT: %{{.+}} = impulse.select // CHECK-NEXT: %{{.+}} = arith.andi // -// --- Window boundary: finalize Welford & reinit DA (enzyme.if) --- -// CHECK-NEXT: %{{.+}}:10 = enzyme.if +// --- Window boundary: finalize Welford & reinit DA (impulse.if) --- +// CHECK-NEXT: %{{.+}}:10 = impulse.if // CHECK: arith.subi // CHECK-NEXT: arith.sitofp // CHECK-NEXT: "enzyme.broadcast" @@ -221,42 +221,42 @@ module { // CHECK-NEXT: arith.divf %[[ONE_1D]] // CHECK-NEXT: math.log // CHECK-NEXT: arith.addf %{{.+}}, %[[LOG10]] -// CHECK-NEXT: enzyme.yield +// CHECK-NEXT: impulse.yield // CHECK-NEXT: }, { -// CHECK-NEXT: enzyme.yield +// CHECK-NEXT: impulse.yield // CHECK-NEXT: }) // // --- Warmup loop yield --- -// CHECK-NEXT: enzyme.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor +// CHECK-NEXT: impulse.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor // CHECK-NEXT: } // // --- Post-warmup sampling loop: 6 iter_args with adapted params --- -// CHECK-NEXT: %[[SLOOP:.+]]:6 = enzyme.for_loop(%[[C0]] : tensor) to(%[[C1]] : tensor) step(%[[C1]] : tensor) iter_args(%[[WARMUP]]#0, %[[WARMUP]]#1, %[[WARMUP]]#2, %[[WARMUP]]#3, %{{.+}}, %[[ACC_INIT]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { +// CHECK-NEXT: %[[SLOOP:.+]]:6 = impulse.for(%[[C0]] : tensor) to(%[[C1]] : tensor) step(%[[C1]] : tensor) iter_args(%[[WARMUP]]#0, %[[WARMUP]]#1, 
%[[WARMUP]]#2, %[[WARMUP]]#3, %{{.+}}, %[[ACC_INIT]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { // CHECK-NEXT: ^bb0(%[[SI:.+]]: tensor, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor<2xui64>, %[[S_SAMP:.+]]: tensor<1x1xf64>, %[[S_ACC:.+]]: tensor<1xi1>): // // --- Momentum with adapted mass matrix from warmup --- -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}}, %{{.+}} = enzyme.random {{.*}} {rng_distribution = #enzyme} -// CHECK-NEXT: %{{.+}} = enzyme.dot %{{.+}}, %[[WARMUP]]#6 {{.*}} -// CHECK-NEXT: %{{.+}} = enzyme.dot %{{.+}}, %[[WARMUP]]#5 {{.*}} -// CHECK-NEXT: %{{.+}} = enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}}, %{{.+}} = impulse.random {{.*}} {rng_distribution = #impulse} +// CHECK-NEXT: %{{.+}} = impulse.dot %{{.+}}, %[[WARMUP]]#6 {{.*}} +// CHECK-NEXT: %{{.+}} = impulse.dot %{{.+}}, %[[WARMUP]]#5 {{.*}} +// CHECK-NEXT: %{{.+}} = impulse.dot {{.*}} lhs_contracting_dimensions = array // CHECK-NEXT: %{{.+}} = arith.mulf // CHECK-NEXT: %{{.+}} = arith.addf // // --- NUTS tree building --- -// CHECK-NEXT: %[[S_TREE:.+]]:18 = enzyme.while_loop +// CHECK-NEXT: %[[S_TREE:.+]]:18 = impulse.while // // --- Sample storage --- // CHECK: %[[S_STORE:.+]] = arith.cmpi sge, %[[SI]], %[[C0]] -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_update_slice %[[S_SAMP]], %[[S_TREE]]#6, %[[SI]], %[[C0]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[S_STORE]] -// CHECK-NEXT: %{{.+}} = enzyme.reshape %[[TRUE]] : (tensor) -> tensor<1xi1> -// CHECK-NEXT: %{{.+}} = enzyme.dynamic_update_slice %[[S_ACC]], %{{.+}}, %[[SI]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[S_STORE]] +// CHECK-NEXT: %{{.+}} = impulse.dynamic_update_slice 
%[[S_SAMP]], %[[S_TREE]]#6, %[[SI]], %[[C0]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[S_STORE]] +// CHECK-NEXT: %{{.+}} = impulse.reshape %[[TRUE]] : (tensor) -> tensor<1xi1> +// CHECK-NEXT: %{{.+}} = impulse.dynamic_update_slice %[[S_ACC]], %{{.+}}, %[[SI]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[S_STORE]] // // --- Sampling yield --- -// CHECK-NEXT: enzyme.yield %[[S_TREE]]#6, %[[S_TREE]]#7, %[[S_TREE]]#8, %{{.+}}, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> +// CHECK-NEXT: impulse.yield %[[S_TREE]]#6, %[[S_TREE]]#7, %[[S_TREE]]#8, %{{.+}}, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> // CHECK-NEXT: } // CHECK-NEXT: return %[[SLOOP]]#4, %[[SLOOP]]#5, %[[SLOOP]]#3 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> // CHECK-NEXT: } @@ -285,28 +285,28 @@ module { // CHECK-DAG: %[[S_ZERO_F:.+]] = arith.constant dense<0.000000e+00> : tensor // // --- Init --- -// CHECK: enzyme.randomSplit -// CHECK: enzyme.dynamic_update_slice +// CHECK: impulse.randomSplit +// CHECK: impulse.dynamic_update_slice // CHECK: arith.negf // CHECK: enzyme.autodiff_region // CHECK: } // // --- Warmup loop: 13 iter_args (no Welford mean/m2/n) --- -// CHECK: %[[S_WARMUP:.+]]:13 = enzyme.for_loop(%[[S_C0]] : tensor) to(%[[S_C10]] : tensor) step(%[[S_C1]] : tensor) iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[S_ONE_1D]], %[[S_ONE_1D]], %[[S_ZERO_F]], %[[S_ZERO_F]], %[[S_ZERO_F]], %[[S_C0]], %[[S_EPS_INIT]], %[[S_C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor { +// CHECK: %[[S_WARMUP:.+]]:13 = impulse.for(%[[S_C0]] : tensor) to(%[[S_C10]] : tensor) step(%[[S_C1]] : tensor) iter_args(%{{.+}}, %{{.+}}, 
%{{.+}}, %{{.+}}, %{{.+}}, %[[S_ONE_1D]], %[[S_ONE_1D]], %[[S_ZERO_F]], %[[S_ZERO_F]], %[[S_ZERO_F]], %[[S_C0]], %[[S_EPS_INIT]], %[[S_C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor { // CHECK-NEXT: ^bb0(%[[S_WI:.+]]: tensor, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor<2xui64>, %{{.+}}: tensor, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor): // // --- Momentum sampling --- -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}}, %{{.+}} = enzyme.random {{.*}} {rng_distribution = #enzyme} -// CHECK-NEXT: %{{.+}} = enzyme.dot -// CHECK-NEXT: %{{.+}} = enzyme.dot -// CHECK-NEXT: %{{.+}} = enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}}, %{{.+}} = impulse.random {{.*}} {rng_distribution = #impulse} +// CHECK-NEXT: %{{.+}} = impulse.dot +// CHECK-NEXT: %{{.+}} = impulse.dot +// CHECK-NEXT: %{{.+}} = impulse.dot {{.*}} lhs_contracting_dimensions = array // CHECK-NEXT: %{{.+}} = arith.mulf // CHECK-NEXT: %{{.+}} = arith.addf // // --- NUTS tree building --- -// CHECK-NEXT: %[[S_WTREE:.+]]:18 = enzyme.while_loop +// CHECK-NEXT: %[[S_WTREE:.+]]:18 = impulse.while // // --- Dual averaging update (same as warmup_both) --- // CHECK: %{{.+}} = arith.maxsi %[[S_WTREE]]#15, %[[S_C1]] @@ -335,7 +335,7 @@ module { // // --- Window boundary: select averaged at end --- // CHECK-NEXT: %{{.+}} = arith.cmpi eq, %[[S_WI]], %[[S_C9]] -// CHECK-NEXT: %{{.+}} = enzyme.select +// CHECK-NEXT: %{{.+}} = impulse.select 
// CHECK-NEXT: %{{.+}} = arith.maximumf %{{.+}}, %[[S_FMIN]] // CHECK-NEXT: %{{.+}} = arith.minimumf %{{.+}}, %[[S_FMAX]] // @@ -348,25 +348,25 @@ module { // CHECK-NEXT: %{{.+}} = arith.cmpi eq // CHECK-NEXT: %{{.+}} = arith.andi // CHECK-NEXT: %{{.+}} = arith.addi -// CHECK-NEXT: %{{.+}} = enzyme.select +// CHECK-NEXT: %{{.+}} = impulse.select // CHECK-NEXT: %{{.+}} = arith.andi // -// --- Window boundary enzyme.if: 7 results (no Welford) --- -// CHECK-NEXT: %{{.+}}:7 = enzyme.if +// --- Window boundary impulse.if: 7 results (no Welford) --- +// CHECK-NEXT: %{{.+}}:7 = impulse.if // CHECK: math.log // CHECK: arith.addf %{{.+}}, %[[S_LOG10]] -// CHECK: enzyme.yield +// CHECK: impulse.yield // CHECK-NEXT: }, { -// CHECK-NEXT: enzyme.yield +// CHECK-NEXT: impulse.yield // CHECK-NEXT: }) // // --- Warmup yield: 13 values --- -// CHECK-NEXT: enzyme.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor +// CHECK-NEXT: impulse.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor // CHECK-NEXT: } // // --- Post-warmup sampling loop --- -// CHECK-NEXT: %[[S_SLOOP:.+]]:6 = enzyme.for_loop(%[[S_C0]] : tensor) to(%[[S_C1]] : tensor) step(%[[S_C1]] : tensor) iter_args(%[[S_WARMUP]]#0, %[[S_WARMUP]]#1, %[[S_WARMUP]]#2, %[[S_WARMUP]]#3, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { -// CHECK: enzyme.while_loop +// CHECK-NEXT: %[[S_SLOOP:.+]]:6 = impulse.for(%[[S_C0]] : tensor) to(%[[S_C1]] : tensor) step(%[[S_C1]] : tensor) iter_args(%[[S_WARMUP]]#0, %[[S_WARMUP]]#1, %[[S_WARMUP]]#2, %[[S_WARMUP]]#3, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> 
tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { +// CHECK: impulse.while // CHECK: return %[[S_SLOOP]]#4, %[[S_SLOOP]]#5, %[[S_SLOOP]]#3 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> // CHECK-NEXT: } @@ -389,31 +389,31 @@ module { // CHECK-DAG: %[[M_EPS_INIT:.+]] = arith.constant dense<4.44{{.+}}> : tensor // // --- Init --- -// CHECK: enzyme.randomSplit +// CHECK: impulse.randomSplit // CHECK: arith.negf // CHECK: enzyme.autodiff_region // CHECK: } attributes // // --- Warmup loop: 16 iter_args (includes Welford state) --- -// CHECK-NEXT: %[[M_WARMUP:.+]]:16 = enzyme.for_loop(%[[M_C0]] : tensor) to(%[[M_C10]] : tensor) step(%[[M_C1]] : tensor) iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[M_C0]], %[[M_EPS_INIT]], %[[M_ZERO_1D]], %[[M_ZERO_1D]], %[[M_C0]], %[[M_C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor { +// CHECK-NEXT: %[[M_WARMUP:.+]]:16 = impulse.for(%[[M_C0]] : tensor) to(%[[M_C10]] : tensor) step(%[[M_C1]] : tensor) iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[M_C0]], %[[M_EPS_INIT]], %[[M_ZERO_1D]], %[[M_ZERO_1D]], %[[M_C0]], %[[M_C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor { // CHECK-NEXT: ^bb0(%[[M_WI:.+]]: tensor, %{{.+}}: 
tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor<2xui64>, %{{.+}}: tensor, %[[M_INV:.+]]: tensor<1x1xf64>, %[[M_SQRT:.+]]: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor<1xf64>, %{{.+}}: tensor<1xf64>, %{{.+}}: tensor, %{{.+}}: tensor): // // --- Momentum sampling with mass matrix --- -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}}, %{{.+}} = enzyme.random {{.*}} {rng_distribution = #enzyme} -// CHECK-NEXT: %{{.+}} = enzyme.dot %{{.+}}, %[[M_SQRT]] {{.*}} -// CHECK-NEXT: %{{.+}} = enzyme.dot %{{.+}}, %[[M_INV]] {{.*}} -// CHECK-NEXT: %{{.+}} = enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}}, %{{.+}} = impulse.random {{.*}} {rng_distribution = #impulse} +// CHECK-NEXT: %{{.+}} = impulse.dot %{{.+}}, %[[M_SQRT]] {{.*}} +// CHECK-NEXT: %{{.+}} = impulse.dot %{{.+}}, %[[M_INV]] {{.*}} +// CHECK-NEXT: %{{.+}} = impulse.dot {{.*}} lhs_contracting_dimensions = array // CHECK-NEXT: %{{.+}} = arith.mulf // CHECK-NEXT: %{{.+}} = arith.addf // // --- NUTS tree building --- -// CHECK-NEXT: %[[M_TREE:.+]]:18 = enzyme.while_loop +// CHECK-NEXT: %[[M_TREE:.+]]:18 = impulse.while // // --- Step size: trivially select same value (no DA) --- // CHECK: %[[M_IS_END:.+]] = arith.cmpi eq, %[[M_WI]], %[[M_C9]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[M_IS_END]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[M_IS_END]] // CHECK-NEXT: %{{.+}} = arith.maximumf %{{.+}}, %[[M_FMIN]] // CHECK-NEXT: %{{.+}} = arith.minimumf %{{.+}}, %[[M_FMAX]] // @@ -423,7 +423,7 @@ module { // CHECK-NEXT: %[[M_IN_WIN:.+]] = arith.andi // // --- Welford update --- -// CHECK-NEXT: %{{.+}} = enzyme.reshape %[[M_TREE]]#6 : (tensor<1x1xf64>) -> tensor<1xf64> +// CHECK-NEXT: %{{.+}} = impulse.reshape %[[M_TREE]]#6 : (tensor<1x1xf64>) -> 
tensor<1xf64> // CHECK-NEXT: %{{.+}} = arith.addi // CHECK-NEXT: %{{.+}} = arith.sitofp // CHECK-NEXT: %{{.+}} = "enzyme.broadcast" @@ -435,42 +435,42 @@ module { // CHECK-NEXT: %{{.+}} = arith.addf // // --- Conditional Welford update --- -// CHECK-NEXT: %{{.+}} = enzyme.select %[[M_IN_WIN]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[M_IN_WIN]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[M_IN_WIN]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[M_IN_WIN]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[M_IN_WIN]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[M_IN_WIN]] // // --- Window schedule --- // CHECK-NEXT: %{{.+}} = arith.cmpi eq // CHECK-NEXT: %{{.+}} = arith.cmpi eq // CHECK-NEXT: %{{.+}} = arith.andi // CHECK-NEXT: %{{.+}} = arith.addi -// CHECK-NEXT: %{{.+}} = enzyme.select +// CHECK-NEXT: %{{.+}} = impulse.select // CHECK-NEXT: %{{.+}} = arith.andi // -// --- Window boundary enzyme.if: 10 results (includes Welford finalization) --- -// CHECK-NEXT: %{{.+}}:10 = enzyme.if +// --- Window boundary impulse.if: 10 results (includes Welford finalization) --- +// CHECK-NEXT: %{{.+}}:10 = impulse.if // CHECK: math.sqrt // CHECK: arith.divf %[[M_ONE_1D]] -// CHECK: enzyme.yield +// CHECK: impulse.yield // CHECK-NEXT: }, { -// CHECK-NEXT: enzyme.yield +// CHECK-NEXT: impulse.yield // CHECK-NEXT: }) // // --- Warmup yield: 16 values --- -// CHECK-NEXT: enzyme.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor +// CHECK-NEXT: impulse.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor<1xf64>, tensor<1xf64>, tensor, tensor // CHECK-NEXT: } // // --- Post-warmup sampling with adapted mass matrix --- -// CHECK-NEXT: %[[M_SLOOP:.+]]:6 = enzyme.for_loop(%[[M_C0]] : tensor) to(%[[M_C1]] : tensor) step(%[[M_C1]] : tensor) 
iter_args(%[[M_WARMUP]]#0, %[[M_WARMUP]]#1, %[[M_WARMUP]]#2, %[[M_WARMUP]]#3, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { -// CHECK: enzyme.dot %{{.+}}, %[[M_WARMUP]]#6 -// CHECK-NEXT: %{{.+}} = enzyme.dot %{{.+}}, %[[M_WARMUP]]#5 -// CHECK: enzyme.while_loop +// CHECK-NEXT: %[[M_SLOOP:.+]]:6 = impulse.for(%[[M_C0]] : tensor) to(%[[M_C1]] : tensor) step(%[[M_C1]] : tensor) iter_args(%[[M_WARMUP]]#0, %[[M_WARMUP]]#1, %[[M_WARMUP]]#2, %[[M_WARMUP]]#3, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { +// CHECK: impulse.dot %{{.+}}, %[[M_WARMUP]]#6 +// CHECK-NEXT: %{{.+}} = impulse.dot %{{.+}}, %[[M_WARMUP]]#5 +// CHECK: impulse.while // CHECK: return %[[M_SLOOP]]#4, %[[M_SLOOP]]#5, %[[M_SLOOP]]#3 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> // CHECK-NEXT: } // ============================================================ // warmup_none: adapt_step_size=false, adapt_mass_matrix=false -// 13 iter_args, trivial enzyme.if (both branches identical) +// 13 iter_args, trivial impulse.if (both branches identical) // ============================================================ // CHECK-LABEL: func.func @warmup_none // CHECK-SAME: (%{{.+}}: tensor<2xui64>, %{{.+}}: tensor, %{{.+}}: tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>) @@ -485,31 +485,31 @@ module { // CHECK-DAG: %[[N_ONE_1D:.+]] = arith.constant dense<1.000000e+00> : tensor<1x1xf64> // // --- Init --- -// CHECK: enzyme.randomSplit +// CHECK: impulse.randomSplit // CHECK: arith.negf // CHECK: enzyme.autodiff_region // CHECK: } attributes // // --- Warmup loop: 13 iter_args --- -// CHECK-NEXT: %[[N_WARMUP:.+]]:13 = enzyme.for_loop(%[[N_C0]] : tensor) to(%[[N_C10]] : tensor) step(%[[N_C1]] : tensor) 
iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[N_ONE_1D]], %[[N_ONE_1D]], %{{.+}}, %{{.+}}, %{{.+}}, %[[N_C0]], %[[N_EPS_INIT]], %[[N_C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor { +// CHECK-NEXT: %[[N_WARMUP:.+]]:13 = impulse.for(%[[N_C0]] : tensor) to(%[[N_C10]] : tensor) step(%[[N_C1]] : tensor) iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[N_ONE_1D]], %[[N_ONE_1D]], %{{.+}}, %{{.+}}, %{{.+}}, %[[N_C0]], %[[N_EPS_INIT]], %[[N_C0]] : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor { // CHECK-NEXT: ^bb0(%[[N_WI:.+]]: tensor, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor<2xui64>, %{{.+}}: tensor, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor, %{{.+}}: tensor): // // --- Momentum sampling --- -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}} = enzyme.randomSplit -// CHECK-NEXT: %{{.+}}, %{{.+}} = enzyme.random {{.*}} {rng_distribution = #enzyme} -// CHECK-NEXT: %{{.+}} = enzyme.dot -// CHECK-NEXT: %{{.+}} = enzyme.dot -// CHECK-NEXT: %{{.+}} = enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}} = impulse.randomSplit +// CHECK-NEXT: %{{.+}}, %{{.+}} = impulse.random {{.*}} {rng_distribution = #impulse} +// CHECK-NEXT: %{{.+}} = impulse.dot +// CHECK-NEXT: %{{.+}} = impulse.dot +// CHECK-NEXT: %{{.+}} = impulse.dot {{.*}} 
lhs_contracting_dimensions = array // CHECK-NEXT: %{{.+}} = arith.mulf // CHECK-NEXT: %{{.+}} = arith.addf // // --- NUTS tree building --- -// CHECK-NEXT: %{{.+}}:18 = enzyme.while_loop +// CHECK-NEXT: %{{.+}}:18 = impulse.while // // --- Step size: trivial select (no dual averaging) --- // CHECK: %[[N_IS_END:.+]] = arith.cmpi eq, %[[N_WI]], %[[N_C9]] -// CHECK-NEXT: %{{.+}} = enzyme.select %[[N_IS_END]] +// CHECK-NEXT: %{{.+}} = impulse.select %[[N_IS_END]] // CHECK-NEXT: %{{.+}} = arith.maximumf %{{.+}}, %[[N_FMIN]] // CHECK-NEXT: %{{.+}} = arith.minimumf %{{.+}}, %[[N_FMAX]] // @@ -521,22 +521,22 @@ module { // CHECK-NEXT: %{{.+}} = arith.cmpi eq // CHECK-NEXT: %{{.+}} = arith.andi // CHECK-NEXT: %{{.+}} = arith.addi -// CHECK-NEXT: %{{.+}} = enzyme.select +// CHECK-NEXT: %{{.+}} = impulse.select // CHECK-NEXT: %{{.+}} = arith.andi // -// --- Trivial enzyme.if: both branches yield identical values --- -// CHECK-NEXT: %{{.+}}:7 = enzyme.if -// CHECK-NEXT: enzyme.yield +// --- Trivial impulse.if: both branches yield identical values --- +// CHECK-NEXT: %{{.+}}:7 = impulse.if +// CHECK-NEXT: impulse.yield // CHECK-NEXT: }, { -// CHECK-NEXT: enzyme.yield +// CHECK-NEXT: impulse.yield // CHECK-NEXT: }) // // --- Warmup yield: 13 values --- -// CHECK-NEXT: enzyme.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor +// CHECK-NEXT: impulse.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor // CHECK-NEXT: } // // --- Post-warmup sampling --- -// CHECK-NEXT: %[[N_SLOOP:.+]]:6 = enzyme.for_loop(%[[N_C0]] : tensor) to(%[[N_C1]] : tensor) step(%[[N_C1]] : tensor) iter_args(%[[N_WARMUP]]#0, %[[N_WARMUP]]#1, %[[N_WARMUP]]#2, %[[N_WARMUP]]#3, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> 
tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { -// CHECK: enzyme.while_loop +// CHECK-NEXT: %[[N_SLOOP:.+]]:6 = impulse.for(%[[N_C0]] : tensor) to(%[[N_C1]] : tensor) step(%[[N_C1]] : tensor) iter_args(%[[N_WARMUP]]#0, %[[N_WARMUP]]#1, %[[N_WARMUP]]#2, %[[N_WARMUP]]#3, %{{.+}}, %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> { +// CHECK: impulse.while // CHECK: return %[[N_SLOOP]]#4, %[[N_SLOOP]]#5, %[[N_SLOOP]]#3 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> // CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ProbProg/mh.mlir b/enzyme/test/MLIR/ProbProg/mh.mlir index 0e45f64d73d7..e33b21282663 100644 --- a/enzyme/test/MLIR/ProbProg/mh.mlir +++ b/enzyme/test/MLIR/ProbProg/mh.mlir @@ -5,8 +5,8 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @test(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<2>, name="t" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<2>, name="t" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %t#1 : tensor<2xui64>, tensor } @@ -17,11 +17,11 @@ module { %c1 = arith.constant 1 : index %c1000 = arith.constant 1000 : index %res:3 = scf.for %i = %c0 to %c1000 step %c1 iter_args(%trace = 
%init_trace, %weight = %init_weight, %rng1 = %rng) -> (tensor<1x2xf64>, tensor, tensor<2xui64>) { - %step1:4 = enzyme.mh @test(%rng1, %mean, %stddev) given %trace weight %weight - { name = "mh_1", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], regenerate_addresses = [[#enzyme.symbol<2>]] } + %step1:4 = impulse.mh @test(%rng1, %mean, %stddev) given %trace weight %weight + { name = "mh_1", selection = [[#impulse.symbol<1>], [#impulse.symbol<2>]], regenerate_addresses = [[#impulse.symbol<2>]] } : (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor, tensor<2xui64>) - %step2:4 = enzyme.mh @test(%step1#3, %mean, %stddev) given %step1#0 weight %step1#1 - { name = "mh_2", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], regenerate_addresses = [[#enzyme.symbol<1>]] } + %step2:4 = impulse.mh @test(%step1#3, %mean, %stddev) given %step1#0 weight %step1#1 + { name = "mh_2", selection = [[#impulse.symbol<1>], [#impulse.symbol<2>]], regenerate_addresses = [[#impulse.symbol<1>]] } : (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor, tensor<2xui64>) scf.yield %step2#0, %step2#1, %step2#3 : tensor<1x2xf64>, tensor, tensor<2xui64> } @@ -41,17 +41,17 @@ module { // CHECK: %[[LOOP:.+]]:3 = scf.for %[[IV:.+]] = %[[C0]] to %[[C1000]] step %[[C1]] iter_args(%[[TR:.+]] = %[[INIT_TRACE]], %[[WT:.+]] = %[[INIT_WEIGHT]], %[[RNG0:.+]] = %[[RNG]]) -> (tensor<1x2xf64>, tensor, tensor<2xui64>) // CHECK-NEXT: %[[REGEN1:.+]]:4 = func.call @test.regenerate_0(%[[TR]], %[[RNG0]], %[[MEAN]], %[[STDDEV]]) : (tensor<1x2xf64>, tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) // CHECK-NEXT: %[[WDIFF1:.+]] = arith.subf %[[REGEN1]]#1, %[[WT]] : tensor -// CHECK-NEXT: %[[RNG1:.+]], %[[U1:.+]] = enzyme.random %[[REGEN1]]#2, %[[ZERO_F]], %[[ONE]] {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %[[RNG1:.+]], 
%[[U1:.+]] = impulse.random %[[REGEN1]]#2, %[[ZERO_F]], %[[ONE]] {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) // CHECK-NEXT: %[[LOG1:.+]] = math.log %[[U1]] : tensor // CHECK-NEXT: %[[ACC1:.+]] = arith.cmpf olt, %[[LOG1]], %[[WDIFF1]] : tensor -// CHECK-NEXT: %[[SEL_TR1:.+]] = enzyme.select %[[ACC1]], %[[REGEN1]]#0, %[[TR]] : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<1x2xf64> +// CHECK-NEXT: %[[SEL_TR1:.+]] = impulse.select %[[ACC1]], %[[REGEN1]]#0, %[[TR]] : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: %[[SEL_WT1:.+]] = arith.select %[[ACC1]], %[[REGEN1]]#1, %[[WT]] : tensor, tensor // CHECK-NEXT: %[[REGEN2:.+]]:4 = func.call @test.regenerate(%[[SEL_TR1]], %[[RNG1]], %[[MEAN]], %[[STDDEV]]) : (tensor<1x2xf64>, tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) // CHECK-NEXT: %[[WDIFF2:.+]] = arith.subf %[[REGEN2]]#1, %[[SEL_WT1]] : tensor -// CHECK-NEXT: %[[RNG2:.+]], %[[U2:.+]] = enzyme.random %[[REGEN2]]#2, %[[ZERO_F]], %[[ONE]] {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %[[RNG2:.+]], %[[U2:.+]] = impulse.random %[[REGEN2]]#2, %[[ZERO_F]], %[[ONE]] {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) // CHECK-NEXT: %[[LOG2:.+]] = math.log %[[U2]] : tensor // CHECK-NEXT: %[[ACC2:.+]] = arith.cmpf olt, %[[LOG2]], %[[WDIFF2]] : tensor -// CHECK-NEXT: %[[SEL_TR2:.+]] = enzyme.select %[[ACC2]], %[[REGEN2]]#0, %[[SEL_TR1]] : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<1x2xf64> +// CHECK-NEXT: %[[SEL_TR2:.+]] = impulse.select %[[ACC2]], %[[REGEN2]]#0, %[[SEL_TR1]] : (tensor, tensor<1x2xf64>, tensor<1x2xf64>) -> tensor<1x2xf64> // CHECK-NEXT: %[[SEL_WT2:.+]] = arith.select %[[ACC2]], %[[REGEN2]]#1, %[[SEL_WT1]] : tensor, tensor // CHECK-NEXT: scf.yield %[[SEL_TR2]], %[[SEL_WT2]], %[[RNG2]] : tensor<1x2xf64>, tensor, tensor<2xui64> // 
CHECK-NEXT: } @@ -66,14 +66,14 @@ module { // CHECK: %[[R_REGEN:.+]]:2 = call @normal(%[[R_ARG1]], %[[R_ARG2]], %[[R_ARG3]]) : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) // CHECK-NEXT: %[[R_LP1:.+]] = call @logpdf(%[[R_REGEN]]#1, %[[R_ARG2]], %[[R_ARG3]]) : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[R_W1:.+]] = arith.addf %[[R_LP1]], %[[R_ZERO]] : tensor -// CHECK-NEXT: %[[R_RS1:.+]] = enzyme.reshape %[[R_REGEN]]#1 : (tensor) -> tensor<1x1xf64> -// CHECK-NEXT: %[[R_TR1:.+]] = enzyme.dynamic_update_slice %[[R_TRACE_INIT]], %[[R_RS1]], %[[R_C0]], %[[R_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> -// CHECK-NEXT: %[[R_KEPT_SLICED:.+]] = enzyme.slice %[[R_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> -// CHECK-NEXT: %[[R_KEPT:.+]] = enzyme.reshape %[[R_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor +// CHECK-NEXT: %[[R_RS1:.+]] = impulse.reshape %[[R_REGEN]]#1 : (tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[R_TR1:.+]] = impulse.dynamic_update_slice %[[R_TRACE_INIT]], %[[R_RS1]], %[[R_C0]], %[[R_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// CHECK-NEXT: %[[R_KEPT_SLICED:.+]] = impulse.slice %[[R_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[R_KEPT:.+]] = impulse.reshape %[[R_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor // CHECK-NEXT: %[[R_LP2:.+]] = call @logpdf(%[[R_KEPT]], %[[R_REGEN]]#1, %[[R_ARG3]]) : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[R_W2:.+]] = arith.addf %[[R_W1]], %[[R_LP2]] : tensor -// CHECK-NEXT: %[[R_RS2:.+]] = enzyme.reshape %[[R_KEPT]] : (tensor) -> tensor<1x1xf64> -// CHECK-NEXT: %[[R_TR2:.+]] = enzyme.dynamic_update_slice %[[R_TR1]], %[[R_RS2]], %[[R_C0]], %[[R_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// CHECK-NEXT: %[[R_RS2:.+]] = impulse.reshape %[[R_KEPT]] : 
(tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[R_TR2:.+]] = impulse.dynamic_update_slice %[[R_TR1]], %[[R_RS2]], %[[R_C0]], %[[R_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // CHECK-NEXT: return %[[R_TR2]], %[[R_W2]], %[[R_REGEN]]#0, %[[R_KEPT]] : tensor<1x2xf64>, tensor, tensor<2xui64>, tensor // CHECK-LABEL: func.func @test.regenerate_0 @@ -82,15 +82,15 @@ module { // CHECK-DAG: %[[R0_C0:.+]] = arith.constant dense<0> : tensor // CHECK-DAG: %[[R0_ZERO:.+]] = arith.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[R0_TRACE_INIT:.+]] = arith.constant dense<0.000000e+00> : tensor<1x2xf64> -// CHECK: %[[R0_KEPT_SLICED:.+]] = enzyme.slice %[[R0_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> -// CHECK-NEXT: %[[R0_KEPT:.+]] = enzyme.reshape %[[R0_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor +// CHECK: %[[R0_KEPT_SLICED:.+]] = impulse.slice %[[R0_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> +// CHECK-NEXT: %[[R0_KEPT:.+]] = impulse.reshape %[[R0_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor // CHECK-NEXT: %[[R0_LP1:.+]] = call @logpdf(%[[R0_KEPT]], %[[R0_ARG2]], %[[R0_ARG3]]) : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[R0_W1:.+]] = arith.addf %[[R0_LP1]], %[[R0_ZERO]] : tensor -// CHECK-NEXT: %[[R0_RS1:.+]] = enzyme.reshape %[[R0_KEPT]] : (tensor) -> tensor<1x1xf64> -// CHECK-NEXT: %[[R0_TR1:.+]] = enzyme.dynamic_update_slice %[[R0_TRACE_INIT]], %[[R0_RS1]], %[[R0_C0]], %[[R0_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// CHECK-NEXT: %[[R0_RS1:.+]] = impulse.reshape %[[R0_KEPT]] : (tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[R0_TR1:.+]] = impulse.dynamic_update_slice %[[R0_TRACE_INIT]], %[[R0_RS1]], %[[R0_C0]], %[[R0_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // CHECK-NEXT: %[[R0_REGEN:.+]]:2 = call @normal(%[[R0_ARG1]], 
%[[R0_KEPT]], %[[R0_ARG3]]) : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) // CHECK-NEXT: %[[R0_LP2:.+]] = call @logpdf(%[[R0_REGEN]]#1, %[[R0_KEPT]], %[[R0_ARG3]]) : (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[R0_W2:.+]] = arith.addf %[[R0_W1]], %[[R0_LP2]] : tensor -// CHECK-NEXT: %[[R0_RS2:.+]] = enzyme.reshape %[[R0_REGEN]]#1 : (tensor) -> tensor<1x1xf64> -// CHECK-NEXT: %[[R0_TR2:.+]] = enzyme.dynamic_update_slice %[[R0_TR1]], %[[R0_RS2]], %[[R0_C0]], %[[R0_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// CHECK-NEXT: %[[R0_RS2:.+]] = impulse.reshape %[[R0_REGEN]]#1 : (tensor) -> tensor<1x1xf64> +// CHECK-NEXT: %[[R0_TR2:.+]] = impulse.dynamic_update_slice %[[R0_TR1]], %[[R0_RS2]], %[[R0_C0]], %[[R0_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // CHECK-NEXT: return %[[R0_TR2]], %[[R0_W2]], %[[R0_REGEN]]#0, %[[R0_REGEN]]#1 : tensor<1x2xf64>, tensor, tensor<2xui64>, tensor diff --git a/enzyme/test/MLIR/ProbProg/nuts_kernel.mlir b/enzyme/test/MLIR/ProbProg/nuts_kernel.mlir index 63ff3455fc23..8604de184ccf 100644 --- a/enzyme/test/MLIR/ProbProg/nuts_kernel.mlir +++ b/enzyme/test/MLIR/ProbProg/nuts_kernel.mlir @@ -5,17 +5,17 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @test(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %s#0, %s#1 : tensor<2xui64>, tensor } func.func @nuts(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>) { %init_trace = arith.constant dense<[[0.0]]> : 
tensor<1x1xf64> %step_size = arith.constant dense<0.1> : tensor - %res:8 = enzyme.mcmc @test(%rng, %mean, %stddev) given %init_trace + %res:8 = impulse.infer @test(%rng, %mean, %stddev) given %init_trace step_size = %step_size - { nuts_config = #enzyme.nuts_config, - name = "nuts", selection = [[#enzyme.symbol<1>]], all_addresses = [[#enzyme.symbol<1>]], num_warmup = 0, num_samples = 1 } + { nuts_config = #impulse.nuts_config, + name = "nuts", selection = [[#impulse.symbol<1>]], all_addresses = [[#impulse.symbol<1>]], num_warmup = 0, num_samples = 1 } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>, tensor) -> (tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor<1x1xf64>) return %res#0, %res#1, %res#2 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> } @@ -36,8 +36,8 @@ module { // CHECK-DAG: %[[MAX_DE:.+]] = arith.constant dense<1.000000e+03> : tensor // // --- RNG splits --- -// CHECK: enzyme.randomSplit -// CHECK: enzyme.randomSplit +// CHECK: impulse.randomSplit +// CHECK: impulse.randomSplit // // --- Initial gradient via autodiff --- // CHECK: enzyme.autodiff_region(%{{.+}}, %[[ONE]]) { @@ -48,33 +48,33 @@ module { // CHECK: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} // // --- Sampling loop --- -// CHECK: %[[SLOOP:.+]]:6 = enzyme.for_loop(%[[C0]] : tensor) to(%[[C1]] : tensor) +// CHECK: %[[SLOOP:.+]]:6 = impulse.for(%[[C0]] : tensor) to(%[[C1]] : tensor) // CHECK-SAME: iter_args(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[INIT_TRACE]], %{{.+}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1>) // CHECK-SAME: -> tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> // CHECK: ^bb0(%[[S_ITER:.+]]: tensor, %[[S_Q:.+]]: tensor<1x1xf64>, %[[S_GRAD:.+]]: tensor<1x1xf64>, %{{.+}}: tensor, %{{.+}}: tensor<2xui64>, %{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<1xi1>): // // --- Momentum sampling --- -// CHECK: enzyme.random 
{{.*}} {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x1xf64>) +// CHECK: impulse.random {{.*}} {rng_distribution = #impulse} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<1x1xf64>) // // --- Kinetic energy --- -// CHECK: enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK: impulse.dot {{.*}} lhs_contracting_dimensions = array // // ============================================================ // Main NUTS tree building loop (outer while) // ============================================================ -// CHECK: %[[TREE:.+]]:18 = enzyme.while_loop({{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64>) -> {{.*}} condition { +// CHECK: %[[TREE:.+]]:18 = impulse.while({{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64>) -> {{.*}} condition { // // --- Condition: depth < max_tree_depth && !turning && !diverging --- // CHECK: arith.cmpi slt, %{{.+}}, %[[C3]] : tensor // CHECK: arith.xori {{.*}} : tensor // CHECK: arith.andi {{.*}} : tensor -// CHECK: enzyme.yield +// CHECK: impulse.yield // CHECK: } body { // // --- Direction sampling --- -// CHECK: enzyme.random {{.*}} {rng_distribution = #enzyme} +// CHECK: impulse.random {{.*}} {rng_distribution = #impulse} // CHECK: arith.cmpf olt, {{.*}} : tensor -// CHECK: enzyme.randomSplit {{.*}} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) +// CHECK: impulse.randomSplit {{.*}} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) // // --- Subtree size: 2^depth --- // CHECK: arith.shli {{.*}}, %{{.+}} : tensor @@ -82,10 +82,10 @@ module { // 
============================================================ // Inner subtree building loop --- // ============================================================ -// CHECK: enzyme.while_loop({{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64>, tensor<3x1xf64>, tensor<3x1xf64>, tensor) -> {{.*}} condition { +// CHECK: impulse.while({{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64>, tensor<3x1xf64>, tensor<3x1xf64>, tensor) -> {{.*}} condition { // CHECK: arith.cmpi slt, {{.*}} : tensor // CHECK: arith.andi {{.*}} : tensor -// CHECK: enzyme.yield +// CHECK: impulse.yield // CHECK: } body { // // --- Leapfrog step --- @@ -104,7 +104,7 @@ module { // CHECK: arith.subf {{.*}} : tensor<1x1xf64> // // --- Kinetic energy --- -// CHECK: enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK: impulse.dot {{.*}} lhs_contracting_dimensions = array // // --- Delta energy and divergence check --- // CHECK: arith.subf {{.*}} : tensor @@ -112,74 +112,74 @@ module { // // --- Tree combination --- // CHECK: arith.cmpi eq, %{{.+}}, %[[C0]] : tensor -// CHECK: enzyme.if -// CHECK: enzyme.yield +// CHECK: impulse.if +// CHECK: impulse.yield // CHECK: }, { -// CHECK: enzyme.log_add_exp -// CHECK: enzyme.logistic -// CHECK: enzyme.random -// CHECK: enzyme.select -// CHECK: enzyme.yield +// CHECK: impulse.log_add_exp +// CHECK: impulse.logistic +// CHECK: impulse.random +// CHECK: impulse.select +// CHECK: impulse.yield // CHECK: }) // // --- Checkpoint updates --- -// CHECK: enzyme.popcount -// CHECK: enzyme.dynamic_update_slice {{.*}} : (tensor<3x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<3x1xf64> 
+// CHECK: impulse.popcount +// CHECK: impulse.dynamic_update_slice {{.*}} : (tensor<3x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<3x1xf64> // // --- Iterative turning check loop --- -// CHECK: enzyme.while_loop({{.*}} : tensor, tensor) -> tensor, tensor condition { -// CHECK: enzyme.yield +// CHECK: impulse.while({{.*}} : tensor, tensor) -> tensor, tensor condition { +// CHECK: impulse.yield // CHECK: } body { -// CHECK: enzyme.dynamic_slice {{.*}} {slice_sizes = array} +// CHECK: impulse.dynamic_slice {{.*}} {slice_sizes = array} // --- Dynamic termination criterion --- -// CHECK: enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK: impulse.dot {{.*}} lhs_contracting_dimensions = array // CHECK: arith.cmpf ole, {{.*}} : tensor // CHECK: arith.ori {{.*}} : tensor -// CHECK: enzyme.yield +// CHECK: impulse.yield // CHECK: } // // --- Subtree yield --- -// CHECK: enzyme.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64>, tensor<3x1xf64>, tensor<3x1xf64>, tensor +// CHECK: impulse.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64>, tensor<3x1xf64>, tensor<3x1xf64>, tensor // CHECK: } // // ============================================================ // Tree combination with biased kernel // ============================================================ // --- Update left/right boundaries --- -// CHECK: enzyme.select {{.*}} : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) -// CHECK: enzyme.select {{.*}} : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) +// CHECK: impulse.select {{.*}} : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) +// CHECK: impulse.select {{.*}} : (tensor, 
tensor<1x1xf64>, tensor<1x1xf64>) // // --- Biased transition: exp, min --- -// CHECK: enzyme.log_add_exp +// CHECK: impulse.log_add_exp // CHECK: math.exp // CHECK: arith.minimumf {{.*}}, %[[ONE]] // // --- Zero probability when turning/diverging --- // CHECK: arith.ori {{.*}} : tensor // CHECK: arith.select {{.*}}, %[[ZERO_F]] -// CHECK: enzyme.random +// CHECK: impulse.random // CHECK: arith.cmpf olt -// CHECK: enzyme.select {{.*}} : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) +// CHECK: impulse.select {{.*}} : (tensor, tensor<1x1xf64>, tensor<1x1xf64>) // // --- Turning check on combined tree --- // CHECK: arith.addf {{.*}} : tensor<1x1xf64> // CHECK: arith.mulf {{.*}} : tensor<1x1xf64> // CHECK: arith.subf {{.*}} : tensor<1x1xf64> -// CHECK: enzyme.dot {{.*}} lhs_contracting_dimensions = array +// CHECK: impulse.dot {{.*}} lhs_contracting_dimensions = array // CHECK: arith.cmpf ole, {{.*}} : tensor // CHECK: arith.ori {{.*}} : tensor // // --- Outer loop yield --- -// CHECK: enzyme.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64> +// CHECK: impulse.yield {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<1x1xf64>, tensor<2xui64> // CHECK: } // // --- Store sample --- // CHECK: arith.cmpi sge, %[[S_ITER]], %[[C0]] -// CHECK: enzyme.dynamic_update_slice {{.*}} : (tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> -// CHECK: enzyme.select +// CHECK: impulse.dynamic_update_slice {{.*}} : (tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x1xf64> +// CHECK: impulse.select // // --- Sampling loop yield --- -// CHECK: enzyme.yield %[[TREE]]#6, %[[TREE]]#7, %[[TREE]]#8, {{.*}} : 
tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> +// CHECK: impulse.yield %[[TREE]]#6, %[[TREE]]#7, %[[TREE]]#8, {{.*}} : tensor<1x1xf64>, tensor<1x1xf64>, tensor, tensor<2xui64>, tensor<1x1xf64>, tensor<1xi1> // CHECK: } // CHECK: return %[[SLOOP]]#4, %[[SLOOP]]#5, %[[SLOOP]]#3 : tensor<1x1xf64>, tensor<1xi1>, tensor<2xui64> // CHECK: } @@ -187,6 +187,6 @@ module { // --- Generated function: test.generate --- // CHECK-LABEL: func.func @test.generate // CHECK-SAME: (%{{.+}}: tensor<1x1xf64>, %{{.+}}: tensor<2xui64>, %{{.+}}: tensor, %{{.+}}: tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) -// CHECK: enzyme.slice %{{.+}} {limit_indices = array, start_indices = array +// CHECK: impulse.slice %{{.+}} {limit_indices = array, start_indices = array // CHECK: call @logpdf // CHECK: return {{.*}} : tensor<1x1xf64>, tensor, tensor<2xui64>, tensor diff --git a/enzyme/test/MLIR/ProbProg/regenerate.mlir b/enzyme/test/MLIR/ProbProg/regenerate.mlir index de8857e530e0..4ca7f510f46e 100644 --- a/enzyme/test/MLIR/ProbProg/regenerate.mlir +++ b/enzyme/test/MLIR/ProbProg/regenerate.mlir @@ -6,15 +6,15 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @model(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> 
(tensor<2xui64>, tensor) return %t#0, %t#1 : tensor<2xui64>, tensor } func.func @test_base(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) { %prev_trace = arith.constant dense<[[0.5, 1.0]]> : tensor<1x2xf64> - %res:4 = enzyme.regenerate @model(%rng, %mean, %stddev) given %prev_trace - { selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]], regenerate_addresses = [[#enzyme.symbol<2>]] } + %res:4 = impulse.regenerate @model(%rng, %mean, %stddev) given %prev_trace + { selection = [[#impulse.symbol<1>], [#impulse.symbol<2>]], regenerate_addresses = [[#impulse.symbol<2>]] } : (tensor<2xui64>, tensor, tensor, tensor<1x2xf64>) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) return %res#0, %res#1, %res#2, %res#3 : tensor<1x2xf64>, tensor, tensor<2xui64>, tensor } @@ -32,17 +32,17 @@ module { // BASE-DAG: %[[C0:.+]] = arith.constant dense<0> : tensor // BASE-DAG: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : tensor // BASE-DAG: %[[TRACE_INIT:.+]] = arith.constant dense<0.000000e+00> : tensor<1x2xf64> -// BASE: %[[KEPT_SLICED:.+]] = enzyme.slice %[[ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> -// BASE-NEXT: %[[KEPT:.+]] = enzyme.reshape %[[KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor +// BASE: %[[KEPT_SLICED:.+]] = impulse.slice %[[ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> +// BASE-NEXT: %[[KEPT:.+]] = impulse.reshape %[[KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor // BASE-NEXT: %[[LP1:.+]] = call @logpdf(%[[KEPT]], %[[ARG2]], %[[ARG3]]) // BASE-NEXT: %[[W1:.+]] = arith.addf %[[LP1]], %[[ZERO]] : tensor -// BASE-NEXT: %[[RS1:.+]] = enzyme.reshape %[[KEPT]] : (tensor) -> tensor<1x1xf64> -// BASE-NEXT: %[[TR1:.+]] = enzyme.dynamic_update_slice %[[TRACE_INIT]], %[[RS1]], %[[C0]], %[[C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// BASE-NEXT: 
%[[RS1:.+]] = impulse.reshape %[[KEPT]] : (tensor) -> tensor<1x1xf64> +// BASE-NEXT: %[[TR1:.+]] = impulse.dynamic_update_slice %[[TRACE_INIT]], %[[RS1]], %[[C0]], %[[C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // BASE-NEXT: %[[REGEN:.+]]:2 = call @normal(%[[ARG1]], %[[KEPT]], %[[ARG3]]) // BASE-NEXT: %[[LP2:.+]] = call @logpdf(%[[REGEN]]#1, %[[KEPT]], %[[ARG3]]) // BASE-NEXT: %[[W2:.+]] = arith.addf %[[W1]], %[[LP2]] : tensor -// BASE-NEXT: %[[RS2:.+]] = enzyme.reshape %[[REGEN]]#1 : (tensor) -> tensor<1x1xf64> -// BASE-NEXT: %[[TR2:.+]] = enzyme.dynamic_update_slice %[[TR1]], %[[RS2]], %[[C0]], %[[C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// BASE-NEXT: %[[RS2:.+]] = impulse.reshape %[[REGEN]]#1 : (tensor) -> tensor<1x1xf64> +// BASE-NEXT: %[[TR2:.+]] = impulse.dynamic_update_slice %[[TR1]], %[[RS2]], %[[C0]], %[[C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // BASE-NEXT: return %[[TR2]], %[[W2]], %[[REGEN]]#0, %[[REGEN]]#1 // ----- @@ -52,22 +52,22 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @inner(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<4> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<4> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %s#1, %t#1 : tensor<2xui64>, tensor, tensor } 
func.func @outer(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:3 = enzyme.sample @inner(%s#0, %s#1, %stddev) { symbol = #enzyme.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:3 = impulse.sample @inner(%s#0, %s#1, %stddev) { symbol = #impulse.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor, tensor) return %t#0, %t#1, %t#2 : tensor<2xui64>, tensor, tensor } func.func @test_hier(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor) { %prev_trace = arith.constant dense<[[0.5, 1.0, 2.0]]> : tensor<1x3xf64> - %res:5 = enzyme.regenerate @outer(%rng, %mean, %stddev) given %prev_trace - { selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>, #enzyme.symbol<3>], [#enzyme.symbol<2>, #enzyme.symbol<4>]], - regenerate_addresses = [[#enzyme.symbol<2>, #enzyme.symbol<4>]] } + %res:5 = impulse.regenerate @outer(%rng, %mean, %stddev) given %prev_trace + { selection = [[#impulse.symbol<1>], [#impulse.symbol<2>, #impulse.symbol<3>], [#impulse.symbol<2>, #impulse.symbol<4>]], + regenerate_addresses = [[#impulse.symbol<2>, #impulse.symbol<4>]] } : (tensor<2xui64>, tensor, tensor, tensor<1x3xf64>) -> (tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor) return %res#0, %res#1, %res#2, %res#3, %res#4 : tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor } @@ -85,16 +85,16 @@ module { // HIER-DAG: %[[O_C0:.+]] = arith.constant dense<0> : tensor // HIER-DAG: %[[O_ZERO:.+]] = arith.constant dense<0.000000e+00> : tensor // HIER-DAG: %[[O_TRACE_INIT:.+]] = arith.constant 
dense<0.000000e+00> : tensor<1x3xf64> -// HIER: %[[O_KEPT_SLICED:.+]] = enzyme.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x3xf64>) -> tensor<1x1xf64> -// HIER-NEXT: %[[O_KEPT:.+]] = enzyme.reshape %[[O_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor +// HIER: %[[O_KEPT_SLICED:.+]] = impulse.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x3xf64>) -> tensor<1x1xf64> +// HIER-NEXT: %[[O_KEPT:.+]] = impulse.reshape %[[O_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor // HIER-NEXT: %[[O_LP1:.+]] = call @logpdf(%[[O_KEPT]], %[[O_ARG2]], %[[O_ARG3]]) // HIER-NEXT: %[[O_W1:.+]] = arith.addf %[[O_LP1]], %[[O_ZERO]] : tensor -// HIER-NEXT: %[[O_RS1:.+]] = enzyme.reshape %[[O_KEPT]] : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[O_TR1:.+]] = enzyme.dynamic_update_slice %[[O_TRACE_INIT]], %[[O_RS1]], %[[O_C0]], %[[O_C0]] : (tensor<1x3xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x3xf64> -// HIER-NEXT: %[[O_SUB_TRACE:.+]] = enzyme.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x3xf64>) -> tensor<1x2xf64> +// HIER-NEXT: %[[O_RS1:.+]] = impulse.reshape %[[O_KEPT]] : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[O_TR1:.+]] = impulse.dynamic_update_slice %[[O_TRACE_INIT]], %[[O_RS1]], %[[O_C0]], %[[O_C0]] : (tensor<1x3xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x3xf64> +// HIER-NEXT: %[[O_SUB_TRACE:.+]] = impulse.slice %[[O_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x3xf64>) -> tensor<1x2xf64> // HIER-NEXT: %[[O_NESTED:.+]]:5 = call @inner.regenerate(%[[O_SUB_TRACE]], %[[O_ARG1]], %[[O_KEPT]], %[[O_ARG3]]) : (tensor<1x2xf64>, tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor, tensor) // HIER-NEXT: %[[O_W2:.+]] = arith.addf %[[O_W1]], %[[O_NESTED]]#1 : tensor -// HIER-NEXT: %[[O_TR2:.+]] = enzyme.dynamic_update_slice %[[O_TR1]], %[[O_NESTED]]#0, %[[O_C0]], %[[O_C1]] 
: (tensor<1x3xf64>, tensor<1x2xf64>, tensor, tensor) -> tensor<1x3xf64> +// HIER-NEXT: %[[O_TR2:.+]] = impulse.dynamic_update_slice %[[O_TR1]], %[[O_NESTED]]#0, %[[O_C0]], %[[O_C1]] : (tensor<1x3xf64>, tensor<1x2xf64>, tensor, tensor) -> tensor<1x3xf64> // HIER-NEXT: return %[[O_TR2]], %[[O_W2]], %[[O_NESTED]]#2, %[[O_NESTED]]#3, %[[O_NESTED]]#4 // HIER-LABEL: func.func @inner.regenerate @@ -103,15 +103,15 @@ module { // HIER-DAG: %[[I_C0:.+]] = arith.constant dense<0> : tensor // HIER-DAG: %[[I_ZERO:.+]] = arith.constant dense<0.000000e+00> : tensor // HIER-DAG: %[[I_TRACE_INIT:.+]] = arith.constant dense<0.000000e+00> : tensor<1x2xf64> -// HIER: %[[I_KEPT_SLICED:.+]] = enzyme.slice %[[I_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_KEPT:.+]] = enzyme.reshape %[[I_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor +// HIER: %[[I_KEPT_SLICED:.+]] = impulse.slice %[[I_ARG0]] {limit_indices = array, start_indices = array, strides = array} : (tensor<1x2xf64>) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_KEPT:.+]] = impulse.reshape %[[I_KEPT_SLICED]] : (tensor<1x1xf64>) -> tensor // HIER-NEXT: %[[I_LP1:.+]] = call @logpdf(%[[I_KEPT]], %[[I_ARG2]], %[[I_ARG3]]) // HIER-NEXT: %[[I_W1:.+]] = arith.addf %[[I_LP1]], %[[I_ZERO]] : tensor -// HIER-NEXT: %[[I_RS1:.+]] = enzyme.reshape %[[I_KEPT]] : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_TR1:.+]] = enzyme.dynamic_update_slice %[[I_TRACE_INIT]], %[[I_RS1]], %[[I_C0]], %[[I_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// HIER-NEXT: %[[I_RS1:.+]] = impulse.reshape %[[I_KEPT]] : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_TR1:.+]] = impulse.dynamic_update_slice %[[I_TRACE_INIT]], %[[I_RS1]], %[[I_C0]], %[[I_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // HIER-NEXT: %[[I_REGEN:.+]]:2 = call @normal(%[[I_ARG1]], %[[I_ARG2]], %[[I_ARG3]]) // HIER-NEXT: %[[I_LP2:.+]] = call 
@logpdf(%[[I_REGEN]]#1, %[[I_ARG2]], %[[I_ARG3]]) // HIER-NEXT: %[[I_W2:.+]] = arith.addf %[[I_W1]], %[[I_LP2]] : tensor -// HIER-NEXT: %[[I_RS2:.+]] = enzyme.reshape %[[I_REGEN]]#1 : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_TR2:.+]] = enzyme.dynamic_update_slice %[[I_TR1]], %[[I_RS2]], %[[I_C0]], %[[I_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// HIER-NEXT: %[[I_RS2:.+]] = impulse.reshape %[[I_REGEN]]#1 : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_TR2:.+]] = impulse.dynamic_update_slice %[[I_TR1]], %[[I_RS2]], %[[I_C0]], %[[I_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // HIER-NEXT: return %[[I_TR2]], %[[I_W2]], %[[I_REGEN]]#0, %[[I_KEPT]], %[[I_REGEN]]#1 diff --git a/enzyme/test/MLIR/ProbProg/roundtrip.mlir b/enzyme/test/MLIR/ProbProg/roundtrip.mlir index 0f1be931c5b6..5a905f636325 100644 --- a/enzyme/test/MLIR/ProbProg/roundtrip.mlir +++ b/enzyme/test/MLIR/ProbProg/roundtrip.mlir @@ -5,31 +5,31 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor // CHECK: func.func @sample(%[[SEED:.+]]: tensor<2xui64>, %[[MEAN:.+]]: tensor, %[[STDDEV:.+]]: tensor) -> (tensor<2xui64>, tensor) { - // CHECK-NEXT: %[[RES:.+]]:2 = enzyme.sample @normal(%[[SEED]], %[[MEAN]], %[[STDDEV]]) {logpdf = @logpdf, name = "r", symbol = #enzyme.symbol<3>} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + // CHECK-NEXT: %[[RES:.+]]:2 = impulse.sample @normal(%[[SEED]], %[[MEAN]], %[[STDDEV]]) {logpdf = @logpdf, name = "r", symbol = #impulse.symbol<3>} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) // CHECK-NEXT: return %[[RES]]#0, %[[RES]]#1 : tensor<2xui64>, tensor // CHECK-NEXT: } func.func @sample(%seed : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %r:2 = enzyme.sample @normal(%seed, %mean, %stddev) { logpdf = @logpdf, name="r", symbol = #enzyme.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> 
(tensor<2xui64>, tensor) + %r:2 = impulse.sample @normal(%seed, %mean, %stddev) { logpdf = @logpdf, name="r", symbol = #impulse.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %r#0, %r#1 : tensor<2xui64>, tensor } // CHECK: func.func @simulate(%[[SEED:.+]]: tensor<2xui64>, %[[MEAN:.+]]: tensor, %[[STDDEV:.+]]: tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) { - // CHECK-NEXT: %[[TRACE:.+]], %[[WEIGHT:.+]], %[[OUTPUTS:.+]]:2 = enzyme.simulate @sample(%[[SEED]], %[[MEAN]], %[[STDDEV]]) {name = "test", selection = {{\[}}[#enzyme.symbol<3>]]} : (tensor<2xui64>, tensor, tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) + // CHECK-NEXT: %[[TRACE:.+]], %[[WEIGHT:.+]], %[[OUTPUTS:.+]]:2 = impulse.simulate @sample(%[[SEED]], %[[MEAN]], %[[STDDEV]]) {name = "test", selection = {{\[}}[#impulse.symbol<3>]]} : (tensor<2xui64>, tensor, tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) // CHECK-NEXT: return %[[TRACE]], %[[WEIGHT]], %[[OUTPUTS]]#0, %[[OUTPUTS]]#1 : tensor<1x1xf64>, tensor, tensor<2xui64>, tensor // CHECK-NEXT: } func.func @simulate(%seed : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) { - %res:4 = enzyme.simulate @sample(%seed, %mean, %stddev) { selection = [[#enzyme.symbol<3>]], name = "test" } : (tensor<2xui64>, tensor, tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) + %res:4 = impulse.simulate @sample(%seed, %mean, %stddev) { selection = [[#impulse.symbol<3>]], name = "test" } : (tensor<2xui64>, tensor, tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) return %res#0, %res#1, %res#2, %res#3 : tensor<1x1xf64>, tensor, tensor<2xui64>, tensor } // CHECK: func.func @generate(%[[SEED:.+]]: tensor<2xui64>, %[[MEAN:.+]]: tensor, %[[STDDEV:.+]]: tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) { // CHECK-NEXT: %[[CONSTRAINT:.+]] = arith.constant dense<{{.*}}1.5{{.*}}> : tensor<1x1xf64> - // CHECK-NEXT: 
%[[TRACE:.+]], %[[WEIGHT:.+]], %[[OUTPUTS:.+]]:2 = enzyme.generate @sample(%[[SEED]], %[[MEAN]], %[[STDDEV]]) given %[[CONSTRAINT]] {constrained_addresses = {{\[}}[#enzyme.symbol<3>]], name = "test", selection = {{\[}}[#enzyme.symbol<3>]]} : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) + // CHECK-NEXT: %[[TRACE:.+]], %[[WEIGHT:.+]], %[[OUTPUTS:.+]]:2 = impulse.generate @sample(%[[SEED]], %[[MEAN]], %[[STDDEV]]) given %[[CONSTRAINT]] {constrained_addresses = {{\[}}[#impulse.symbol<3>]], name = "test", selection = {{\[}}[#impulse.symbol<3>]]} : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) // CHECK-NEXT: return %[[TRACE]], %[[WEIGHT]], %[[OUTPUTS]]#0, %[[OUTPUTS]]#1 : tensor<1x1xf64>, tensor, tensor<2xui64>, tensor // CHECK-NEXT: } func.func @generate(%seed : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) { %constraint = arith.constant dense<1.5> : tensor<1x1xf64> - %res:4 = enzyme.generate @sample(%seed, %mean, %stddev) given %constraint { selection = [[#enzyme.symbol<3>]], constrained_addresses = [[#enzyme.symbol<3>]], name = "test" } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) + %res:4 = impulse.generate @sample(%seed, %mean, %stddev) given %constraint { selection = [[#impulse.symbol<3>]], constrained_addresses = [[#impulse.symbol<3>]], name = "test" } : (tensor<2xui64>, tensor, tensor, tensor<1x1xf64>) -> (tensor<1x1xf64>, tensor, tensor<2xui64>, tensor) return %res#0, %res#1, %res#2, %res#3 : tensor<1x1xf64>, tensor, tensor<2xui64>, tensor } } diff --git a/enzyme/test/MLIR/ProbProg/simulate.mlir b/enzyme/test/MLIR/ProbProg/simulate.mlir index 7dd04c3138e9..d50284af32b2 100644 --- a/enzyme/test/MLIR/ProbProg/simulate.mlir +++ b/enzyme/test/MLIR/ProbProg/simulate.mlir @@ -6,14 +6,14 @@ module { func.func private @logpdf(%x : tensor, 
%mean : tensor, %stddev : tensor) -> tensor func.func @model(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %t#1 : tensor<2xui64>, tensor } func.func @test_base(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) { - %res:4 = enzyme.simulate @model(%rng, %mean, %stddev) - { selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]] } + %res:4 = impulse.simulate @model(%rng, %mean, %stddev) + { selection = [[#impulse.symbol<1>], [#impulse.symbol<2>]] } : (tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor) return %res#0, %res#1, %res#2, %res#3 : tensor<1x2xf64>, tensor, tensor<2xui64>, tensor } @@ -33,13 +33,13 @@ module { // BASE: %[[S1:.+]]:2 = call @normal(%[[ARG0]], %[[ARG1]], %[[ARG2]]) // BASE-NEXT: %[[LP1:.+]] = call @logpdf(%[[S1]]#1, %[[ARG1]], %[[ARG2]]) // BASE-NEXT: %[[W1:.+]] = arith.addf %[[LP1]], %[[ZERO]] : tensor -// BASE-NEXT: %[[RS1:.+]] = enzyme.reshape %[[S1]]#1 : (tensor) -> tensor<1x1xf64> -// BASE-NEXT: %[[TR1:.+]] = enzyme.dynamic_update_slice %[[TRACE_INIT]], %[[RS1]], %[[C0]], %[[C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// BASE-NEXT: %[[RS1:.+]] = impulse.reshape %[[S1]]#1 : (tensor) -> tensor<1x1xf64> +// BASE-NEXT: 
%[[TR1:.+]] = impulse.dynamic_update_slice %[[TRACE_INIT]], %[[RS1]], %[[C0]], %[[C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // BASE-NEXT: %[[S2:.+]]:2 = call @normal(%[[S1]]#0, %[[S1]]#1, %[[ARG2]]) // BASE-NEXT: %[[LP2:.+]] = call @logpdf(%[[S2]]#1, %[[S1]]#1, %[[ARG2]]) // BASE-NEXT: %[[W2:.+]] = arith.addf %[[W1]], %[[LP2]] : tensor -// BASE-NEXT: %[[RS2:.+]] = enzyme.reshape %[[S2]]#1 : (tensor) -> tensor<1x1xf64> -// BASE-NEXT: %[[TR2:.+]] = enzyme.dynamic_update_slice %[[TR1]], %[[RS2]], %[[C0]], %[[C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// BASE-NEXT: %[[RS2:.+]] = impulse.reshape %[[S2]]#1 : (tensor) -> tensor<1x1xf64> +// BASE-NEXT: %[[TR2:.+]] = impulse.dynamic_update_slice %[[TR1]], %[[RS2]], %[[C0]], %[[C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // BASE-NEXT: return %[[TR2]], %[[W2]], %[[S2]]#0, %[[S2]]#1 // ----- @@ -49,20 +49,20 @@ module { func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor func.func @inner(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor, tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<4> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<3> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<4> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %s#1, %t#1 : tensor<2xui64>, tensor, tensor } func.func @outer(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor, 
tensor) { - %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:3 = enzyme.sample @inner(%s#0, %s#1, %stddev) { symbol = #enzyme.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor, tensor) + %s:2 = impulse.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #impulse.symbol<1> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:3 = impulse.sample @inner(%s#0, %s#1, %stddev) { symbol = #impulse.symbol<2> } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor, tensor) return %t#0, %t#1, %t#2 : tensor<2xui64>, tensor, tensor } func.func @test_hier(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor) { - %res:5 = enzyme.simulate @outer(%rng, %mean, %stddev) - { selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>, #enzyme.symbol<3>], [#enzyme.symbol<2>, #enzyme.symbol<4>]] } + %res:5 = impulse.simulate @outer(%rng, %mean, %stddev) + { selection = [[#impulse.symbol<1>], [#impulse.symbol<2>, #impulse.symbol<3>], [#impulse.symbol<2>, #impulse.symbol<4>]] } : (tensor<2xui64>, tensor, tensor) -> (tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor) return %res#0, %res#1, %res#2, %res#3, %res#4 : tensor<1x3xf64>, tensor, tensor<2xui64>, tensor, tensor } @@ -82,11 +82,11 @@ module { // HIER: %[[O_S1:.+]]:2 = call @normal(%[[O_ARG0]], %[[O_ARG1]], %[[O_ARG2]]) // HIER-NEXT: %[[O_LP1:.+]] = call @logpdf(%[[O_S1]]#1, %[[O_ARG1]], %[[O_ARG2]]) // HIER-NEXT: %[[O_W1:.+]] = arith.addf %[[O_LP1]], %[[O_ZERO]] : tensor -// HIER-NEXT: %[[O_RS1:.+]] = enzyme.reshape %[[O_S1]]#1 : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[O_TR1:.+]] = enzyme.dynamic_update_slice %[[O_TRACE_INIT]], %[[O_RS1]], %[[O_C0]], %[[O_C0]] : (tensor<1x3xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x3xf64> +// HIER-NEXT: %[[O_RS1:.+]] = impulse.reshape 
%[[O_S1]]#1 : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[O_TR1:.+]] = impulse.dynamic_update_slice %[[O_TRACE_INIT]], %[[O_RS1]], %[[O_C0]], %[[O_C0]] : (tensor<1x3xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x3xf64> // HIER-NEXT: %[[O_NESTED:.+]]:5 = call @inner.simulate(%[[O_S1]]#0, %[[O_S1]]#1, %[[O_ARG2]]) : (tensor<2xui64>, tensor, tensor) -> (tensor<1x2xf64>, tensor, tensor<2xui64>, tensor, tensor) // HIER-NEXT: %[[O_W2:.+]] = arith.addf %[[O_W1]], %[[O_NESTED]]#1 : tensor -// HIER-NEXT: %[[O_TR2:.+]] = enzyme.dynamic_update_slice %[[O_TR1]], %[[O_NESTED]]#0, %[[O_C0]], %[[O_C1]] : (tensor<1x3xf64>, tensor<1x2xf64>, tensor, tensor) -> tensor<1x3xf64> +// HIER-NEXT: %[[O_TR2:.+]] = impulse.dynamic_update_slice %[[O_TR1]], %[[O_NESTED]]#0, %[[O_C0]], %[[O_C1]] : (tensor<1x3xf64>, tensor<1x2xf64>, tensor, tensor) -> tensor<1x3xf64> // HIER-NEXT: return %[[O_TR2]], %[[O_W2]], %[[O_NESTED]]#2, %[[O_NESTED]]#3, %[[O_NESTED]]#4 // HIER-LABEL: func.func @inner.simulate @@ -98,11 +98,11 @@ module { // HIER: %[[I_S1:.+]]:2 = call @normal(%[[I_ARG0]], %[[I_ARG1]], %[[I_ARG2]]) // HIER-NEXT: %[[I_LP1:.+]] = call @logpdf(%[[I_S1]]#1, %[[I_ARG1]], %[[I_ARG2]]) // HIER-NEXT: %[[I_W1:.+]] = arith.addf %[[I_LP1]], %[[I_ZERO]] : tensor -// HIER-NEXT: %[[I_RS1:.+]] = enzyme.reshape %[[I_S1]]#1 : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_TR1:.+]] = enzyme.dynamic_update_slice %[[I_TRACE_INIT]], %[[I_RS1]], %[[I_C0]], %[[I_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// HIER-NEXT: %[[I_RS1:.+]] = impulse.reshape %[[I_S1]]#1 : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_TR1:.+]] = impulse.dynamic_update_slice %[[I_TRACE_INIT]], %[[I_RS1]], %[[I_C0]], %[[I_C0]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // HIER-NEXT: %[[I_S2:.+]]:2 = call @normal(%[[I_S1]]#0, %[[I_ARG1]], %[[I_ARG2]]) // HIER-NEXT: %[[I_LP2:.+]] = call @logpdf(%[[I_S2]]#1, %[[I_ARG1]], %[[I_ARG2]]) // HIER-NEXT: %[[I_W2:.+]] = 
arith.addf %[[I_W1]], %[[I_LP2]] : tensor -// HIER-NEXT: %[[I_RS2:.+]] = enzyme.reshape %[[I_S2]]#1 : (tensor) -> tensor<1x1xf64> -// HIER-NEXT: %[[I_TR2:.+]] = enzyme.dynamic_update_slice %[[I_TR1]], %[[I_RS2]], %[[I_C0]], %[[I_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> +// HIER-NEXT: %[[I_RS2:.+]] = impulse.reshape %[[I_S2]]#1 : (tensor) -> tensor<1x1xf64> +// HIER-NEXT: %[[I_TR2:.+]] = impulse.dynamic_update_slice %[[I_TR1]], %[[I_RS2]], %[[I_C0]], %[[I_C1]] : (tensor<1x2xf64>, tensor<1x1xf64>, tensor, tensor) -> tensor<1x2xf64> // HIER-NEXT: return %[[I_TR2]], %[[I_W2]], %[[I_S2]]#0, %[[I_S1]]#1, %[[I_S2]]#1 diff --git a/enzyme/test/MLIR/ProbProg/untraced_call.mlir b/enzyme/test/MLIR/ProbProg/untraced_call.mlir index 9372b2e421ad..b88c96678f70 100644 --- a/enzyme/test/MLIR/ProbProg/untraced_call.mlir +++ b/enzyme/test/MLIR/ProbProg/untraced_call.mlir @@ -4,14 +4,14 @@ module { func.func private @normal(%seed : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) func.func @test(%seed : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { - %s:2 = enzyme.sample @normal(%seed, %mean, %stddev) { name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %t:2 = enzyme.sample @normal(%s#0, %s#1, %stddev) { name="t" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %s:2 = impulse.sample @normal(%seed, %mean, %stddev) { name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = impulse.sample @normal(%s#0, %s#1, %stddev) { name="t" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %t#0, %t#1 : tensor<2xui64>, tensor } func.func @main(%seed : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { // CHECK: %0:2 = call @test.call(%arg0, %arg1, %arg2) : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %o:2 = enzyme.untracedCall @test(%seed, %mean, %stddev) { name = 
"test" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %o:2 = impulse.untracedCall @test(%seed, %mean, %stddev) { name = "test" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %o#0, %o#1 : tensor<2xui64>, tensor } } From e6a06ea750ae1fc7b2b5d165439605848224db58 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 16:38:49 -0500 Subject: [PATCH 2/5] format --- enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h | 3 +- .../MLIR/Dialect/Impulse/ImpulseDialect.cpp | 4 +- .../MLIR/Dialect/Impulse/ImpulseOps.cpp | 3 +- .../ImpulseAutoDiffOpInterfaceImpl.cpp | 15 +- .../Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp | 6 +- enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp | 227 +++++++++--------- .../Enzyme/MLIR/Interfaces/TransformUtils.h | 6 +- 7 files changed, 131 insertions(+), 133 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h b/enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h index 65a1309e55d7..8f58b54e0f00 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/Impulse.h @@ -1,4 +1,5 @@ -//===- Impulse.h - Impulse dialect --------------------------------*- C++ -*-===// +//===- Impulse.h - Impulse dialect --------------------------------*- C++ +//-*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp index cdb8131e8da2..28a276dfeadc 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseDialect.cpp @@ -1,7 +1,7 @@ -#include "Dialect/Impulse/Impulse.h" #include "Dialect/Dialect.h" -#include "mlir/IR/DialectImplementation.h" +#include "Dialect/Impulse/Impulse.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" #include "Dialect/Impulse/ImpulseEnums.cpp.inc" diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp index 2ea8718344a2..9106b531fdcb 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp @@ -1,4 +1,5 @@ -//===- ImpulseOps.cpp - Impulse dialect ops ----------------------*- C++ -*-===// +//===- ImpulseOps.cpp - Impulse dialect ops ----------------------*- C++ +//-*-===// // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception diff --git a/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp index a68f6386c6ff..1509f3b84b54 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp @@ -1,11 +1,3 @@ -//===- ImpulseAutoDiffOpInterfaceImpl.cpp -------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - #include "Implementations/CoreDialectsAutoDiffImplementations.h" #include "Dialect/Impulse/Impulse.h" @@ -21,8 +13,7 @@ namespace { void mlir::enzyme::registerImpulseDialectAutoDiffInterface( DialectRegistry ®istry) { - registry.addExtension( - +[](MLIRContext *context, impulse::ImpulseDialect *) { - registerInterfaces(context); - }); + registry.addExtension(+[](MLIRContext *context, impulse::ImpulseDialect *) { + registerInterfaces(context); + }); } diff --git a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp index af271c7c3387..c65709b22de4 100644 --- a/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp +++ b/enzyme/Enzyme/MLIR/Integrations/c/EnzymeMLIR.cpp @@ -61,7 +61,7 @@ MlirAttribute enzymeSupportAttrGet(MlirContext ctx, EnzymeSupportKind kind, mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), upperBound); return wrap(mlir::impulse::SupportAttr::get(mlirCtx, supportKind, lowerAttr, - upperAttr)); + upperAttr)); } MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, @@ -70,8 +70,8 @@ MlirAttribute enzymeHMCConfigAttrGet(MlirContext ctx, double trajectoryLength, auto trajectoryLengthAttr = mlir::FloatAttr::get(mlir::Float64Type::get(mlirCtx), trajectoryLength); - return wrap(mlir::impulse::HMCConfigAttr::get(mlirCtx, trajectoryLengthAttr, - adaptStepSize, adaptMassMatrix)); + return wrap(mlir::impulse::HMCConfigAttr::get( + mlirCtx, trajectoryLengthAttr, adaptStepSize, adaptMassMatrix)); } MlirAttribute enzymeNUTSConfigAttrGet(MlirContext ctx, int64_t maxTreeDepth, diff --git a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp index 2f1803331a68..0a30eb5cd454 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp @@ -62,10 +62,10 @@ SmallVector 
NUTSTreeState::getTypes() const { } Value impulse::conditionalDump(OpBuilder &builder, Location loc, Value value, - StringRef label, bool debugDump) { + StringRef label, bool debugDump) { if (debugDump) { return impulse::DumpOp::create(builder, loc, value.getType(), value, - builder.getStringAttr(label)) + builder.getStringAttr(label)) .getOutput(); } return value; @@ -137,8 +137,8 @@ static Value reverseRowsAndColumns(OpBuilder &builder, Location loc, } Value impulse::applyInverseMassMatrix(OpBuilder &builder, Location loc, - Value invMass, Value momentum, - RankedTensorType positionType) { + Value invMass, Value momentum, + RankedTensorType positionType) { if (!invMass) { return momentum; } @@ -164,8 +164,8 @@ Value impulse::applyInverseMassMatrix(OpBuilder &builder, Location loc, } Value impulse::computeKineticEnergy(OpBuilder &builder, Location loc, - Value momentum, Value invMass, - RankedTensorType positionType) { + Value momentum, Value invMass, + RankedTensorType positionType) { auto elemType = positionType.getElementType(); auto scalarType = RankedTensorType::get({}, elemType); @@ -188,8 +188,8 @@ Value impulse::computeKineticEnergy(OpBuilder &builder, Location loc, } Value impulse::computeMassMatrixSqrt(OpBuilder &builder, Location loc, - Value invMass, - RankedTensorType positionType) { + Value invMass, + RankedTensorType positionType) { if (!invMass) { return Value(); } @@ -213,7 +213,7 @@ Value impulse::computeMassMatrixSqrt(OpBuilder &builder, Location loc, auto reversedInvMass = reverseRowsAndColumns(builder, loc, invMass); auto L_reversed = impulse::CholeskyOp::create(builder, loc, invMassType, reversedInvMass, - /*lower=*/builder.getBoolAttr(true)); + /*lower=*/builder.getBoolAttr(true)); auto massMatrixSqrtInvT = reverseRowsAndColumns(builder, loc, L_reversed); auto identityMatrix = createIdentityMatrix(builder, loc, invMassType); auto massMatrixSqrt = impulse::TriangularSolveOp::create( @@ -223,17 +223,16 @@ Value 
impulse::computeMassMatrixSqrt(OpBuilder &builder, Location loc, /*unit_diagonal=*/builder.getBoolAttr(false), /*transpose_a=*/ impulse::TransposeAttr::get(builder.getContext(), - impulse::Transpose::TRANSPOSE)); + impulse::Transpose::TRANSPOSE)); return massMatrixSqrt; } } -std::pair impulse::sampleMomentum(OpBuilder &builder, Location loc, - Value rng, Value invMass, - Value massMatrixSqrt, - RankedTensorType positionType, - bool debugDump) { +std::pair +impulse::sampleMomentum(OpBuilder &builder, Location loc, Value rng, + Value invMass, Value massMatrixSqrt, + RankedTensorType positionType, bool debugDump) { auto elemType = positionType.getElementType(); auto scalarType = RankedTensorType::get({}, elemType); @@ -252,7 +251,7 @@ std::pair impulse::sampleMomentum(OpBuilder &builder, Location loc builder, loc, TypeRange{rng.getType(), positionType}, rng, zeroConst, oneConst, impulse::RngDistributionAttr::get(builder.getContext(), - impulse::RngDistribution::NORMAL)); + impulse::RngDistribution::NORMAL)); auto rngOut = randomOp.getOutputRngState(); auto eps = randomOp.getResult(); @@ -307,8 +306,8 @@ static Value scatterPositionToTrace(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(info.traceOffset))); SmallVector updateIndices{c0, traceOffset}; - result = impulse::DynamicUpdateSliceOp::create(builder, loc, traceType, - result, slice, updateIndices); + result = impulse::DynamicUpdateSliceOp::create( + builder, loc, traceType, result, slice, updateIndices); } return result; } @@ -344,16 +343,16 @@ static Value gatherPositionFromTrace(OpBuilder &builder, Location loc, DenseElementsAttr::get(i64TensorType, builder.getI64IntegerAttr(info.offset))); SmallVector updateIndices{c0, posOffset}; - result = impulse::DynamicUpdateSliceOp::create(builder, loc, positionType2d, - result, slice, updateIndices); + result = impulse::DynamicUpdateSliceOp::create( + builder, loc, positionType2d, result, slice, updateIndices); } return 
result; } GradientResult impulse::computePotentialAndGradient(OpBuilder &builder, - Location loc, Value position, - Value rng, - const HMCContext &ctx) { + Location loc, + Value position, Value rng, + const HMCContext &ctx) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -458,10 +457,11 @@ GradientResult impulse::computePotentialAndGradient(OpBuilder &builder, }; } -IntegrationResult impulse::computeIntegrationStep(OpBuilder &builder, Location loc, - const IntegratorState &leaf, - Value rng, Value direction, - const HMCContext &ctx) { +IntegrationResult impulse::computeIntegrationStep(OpBuilder &builder, + Location loc, + const IntegratorState &leaf, + Value rng, Value direction, + const HMCContext &ctx) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -477,12 +477,12 @@ IntegrationResult impulse::computeIntegrationStep(OpBuilder &builder, Location l ArrayRef shape = positionType.getShape(); auto stepSizeBroadcast = BroadcastOp::create(builder, loc, positionType, signedStepSize, - builder.getDenseI64ArrayAttr(shape)); + builder.getDenseI64ArrayAttr(shape)); auto halfStep = arith::MulFOp::create(builder, loc, halfConst, signedStepSize); auto halfStepBroadcast = BroadcastOp::create(builder, loc, positionType, halfStep, - builder.getDenseI64ArrayAttr(shape)); + builder.getDenseI64ArrayAttr(shape)); // 1. 
Half step momentum: p_half = p - 0.5 * eps * grad auto deltaP1 = @@ -507,7 +507,7 @@ IntegrationResult impulse::computeIntegrationStep(OpBuilder &builder, Location l } Value impulse::checkTurning(OpBuilder &builder, Location loc, Value pLeft, - Value pRight, Value pSum, const NUTSContext &ctx) { + Value pRight, Value pSum, const NUTSContext &ctx) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -555,16 +555,17 @@ Value impulse::checkTurning(OpBuilder &builder, Location loc, Value pLeft, } Value impulse::computeUniformTransitionProb(OpBuilder &builder, Location loc, - Value currentWeight, Value newWeight) { + Value currentWeight, + Value newWeight) { Value weightDiff = arith::SubFOp::create(builder, loc, newWeight, currentWeight); return impulse::LogisticOp::create(builder, loc, weightDiff.getType(), - weightDiff); + weightDiff); } Value impulse::computeBiasedTransitionProb(OpBuilder &builder, Location loc, - Value currentWeight, Value newWeight, - Value turning, Value diverging) { + Value currentWeight, Value newWeight, + Value turning, Value diverging) { auto resultType = cast(currentWeight.getType()); auto elemType = resultType.getElementType(); @@ -587,10 +588,10 @@ Value impulse::computeBiasedTransitionProb(OpBuilder &builder, Location loc, } NUTSTreeState impulse::combineTrees(OpBuilder &builder, Location loc, - const NUTSTreeState &tree, - const NUTSTreeState &subTree, Value direction, - Value rng, bool biased, - const NUTSContext &ctx) { + const NUTSTreeState &tree, + const NUTSTreeState &subTree, + Value direction, Value rng, bool biased, + const NUTSContext &ctx) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -604,18 +605,18 @@ NUTSTreeState impulse::combineTrees(OpBuilder &builder, Location loc, DenseElementsAttr::get(scalarType, builder.getFloatAttr(elemType, 1.0))); auto qLeft = 
impulse::SelectOp::create(builder, loc, positionType, direction, - tree.q_left, subTree.q_left); + tree.q_left, subTree.q_left); auto pLeft = impulse::SelectOp::create(builder, loc, positionType, direction, - tree.p_left, subTree.p_left); + tree.p_left, subTree.p_left); auto gradLeft = impulse::SelectOp::create( builder, loc, positionType, direction, tree.grad_left, subTree.grad_left); auto qRight = impulse::SelectOp::create(builder, loc, positionType, direction, - subTree.q_right, tree.q_right); + subTree.q_right, tree.q_right); auto pRight = impulse::SelectOp::create(builder, loc, positionType, direction, - subTree.p_right, tree.p_right); + subTree.p_right, tree.p_right); auto gradRight = impulse::SelectOp::create(builder, loc, positionType, direction, - subTree.grad_right, tree.grad_right); + subTree.grad_right, tree.grad_right); auto combinedWeight = impulse::LogAddExpOp::create( builder, loc, scalarType, tree.weight, subTree.weight); @@ -635,7 +636,7 @@ NUTSTreeState impulse::combineTrees(OpBuilder &builder, Location loc, builder, loc, TypeRange{rng.getType(), scalarType}, rng, zeroConst, oneConst, impulse::RngDistributionAttr::get(builder.getContext(), - impulse::RngDistribution::UNIFORM)); + impulse::RngDistribution::UNIFORM)); auto rngOut = randomOp.getOutputRngState(); auto uniformSample = randomOp.getResult(); @@ -644,10 +645,10 @@ NUTSTreeState impulse::combineTrees(OpBuilder &builder, Location loc, auto qProposal = impulse::SelectOp::create(builder, loc, positionType, acceptNew, - subTree.q_proposal, tree.q_proposal); + subTree.q_proposal, tree.q_proposal); auto gradProposal = impulse::SelectOp::create(builder, loc, positionType, acceptNew, - subTree.grad_proposal, tree.grad_proposal); + subTree.grad_proposal, tree.grad_proposal); auto UProposal = impulse::SelectOp::create( builder, loc, scalarType, acceptNew, subTree.U_proposal, tree.U_proposal); auto HProposal = impulse::SelectOp::create( @@ -698,8 +699,8 @@ NUTSTreeState 
impulse::combineTrees(OpBuilder &builder, Location loc, } InitialHMCState impulse::InitHMC(OpBuilder &builder, Location loc, Value rng, - const HMCContext &ctx, Value initialPosition, - bool debugDump) { + const HMCContext &ctx, Value initialPosition, + bool debugDump) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -859,8 +860,8 @@ InitialHMCState impulse::InitHMC(OpBuilder &builder, Location loc, Value rng, } MCMCKernelResult impulse::SampleHMC(OpBuilder &builder, Location loc, Value q, - Value grad, Value U, Value rng, - const HMCContext &ctx, bool debugDump) { + Value grad, Value U, Value rng, + const HMCContext &ctx, bool debugDump) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -922,7 +923,7 @@ MCMCKernelResult impulse::SampleHMC(OpBuilder &builder, Location loc, Value q, scalarType, rngTransition.getType()}; auto forLoopOp = impulse::ForOp::create(builder, loc, loopResultTypes, c0, numSteps, c1, - ValueRange{q, p0, grad, U, rngTransition}); + ValueRange{q, p0, grad, U, rngTransition}); Block *loopBody = builder.createBlock(&forLoopOp.getRegion()); loopBody->addArgument(i64TensorType, loc); // iv @@ -976,7 +977,7 @@ MCMCKernelResult impulse::SampleHMC(OpBuilder &builder, Location loc, Value q, builder, loc, TypeRange{rngAfterLeapfrog.getType(), scalarType}, rngAfterLeapfrog, zeroConst, oneConst, impulse::RngDistributionAttr::get(builder.getContext(), - impulse::RngDistribution::UNIFORM)); + impulse::RngDistribution::UNIFORM)); auto randUniform = randomOp.getResult(); // accepted = u < α @@ -985,18 +986,18 @@ MCMCKernelResult impulse::SampleHMC(OpBuilder &builder, Location loc, Value q, // 8. 
Select between original and proposal auto qFinal = impulse::SelectOp::create(builder, loc, positionType, - acceptedTensor, qProposal, q); - auto gradFinal = impulse::SelectOp::create(builder, loc, positionType, - acceptedTensor, gradProposal, grad); + acceptedTensor, qProposal, q); + auto gradFinal = impulse::SelectOp::create( + builder, loc, positionType, acceptedTensor, gradProposal, grad); auto UFinal = impulse::SelectOp::create(builder, loc, scalarType, - acceptedTensor, UProposal, U); + acceptedTensor, UProposal, U); return {qFinal, gradFinal, UFinal, acceptedTensor, accProb, rngNext}; } MCMCKernelResult impulse::SampleNUTS(OpBuilder &builder, Location loc, Value q, - Value grad, Value U, Value rng, - const NUTSContext &ctx, bool debugDump) { + Value grad, Value U, Value rng, + const NUTSContext &ctx, bool debugDump) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -1085,8 +1086,8 @@ MCMCKernelResult impulse::SampleNUTS(OpBuilder &builder, Location loc, Value q, } NUTSTreeState impulse::buildBaseTree(OpBuilder &builder, Location loc, - const IntegratorState &leaf, Value rng, - Value direction, const NUTSContext &ctx) { + const IntegratorState &leaf, Value rng, + Value direction, const NUTSContext &ctx) { auto positionType = ctx.getPositionType(); auto scalarType = ctx.getScalarType(); auto elemType = ctx.getElementType(); @@ -1163,25 +1164,25 @@ NUTSTreeState impulse::buildBaseTree(OpBuilder &builder, Location loc, } IntegratorState impulse::getLeafFromTree(OpBuilder &builder, Location loc, - const NUTSTreeState &tree, - Value direction, const NUTSContext &ctx) { + const NUTSTreeState &tree, + Value direction, + const NUTSContext &ctx) { auto positionType = ctx.getPositionType(); auto leafQ = impulse::SelectOp::create(builder, loc, positionType, direction, - tree.q_right, tree.q_left); + tree.q_right, tree.q_left); auto leafP = impulse::SelectOp::create(builder, loc, positionType, 
direction, - tree.p_right, tree.p_left); + tree.p_right, tree.p_left); auto leafGrad = impulse::SelectOp::create( builder, loc, positionType, direction, tree.grad_right, tree.grad_left); return {leafQ, leafP, leafGrad}; } -SubtreeBuildResult impulse::buildIterativeSubtree(OpBuilder &builder, Location loc, - const NUTSTreeState &initialTree, - Value direction, Value pCkpts, - Value pSumCkpts, - const NUTSContext &ctx, - bool debugDump) { +SubtreeBuildResult +impulse::buildIterativeSubtree(OpBuilder &builder, Location loc, + const NUTSTreeState &initialTree, + Value direction, Value pCkpts, Value pSumCkpts, + const NUTSContext &ctx, bool debugDump) { auto i1TensorType = RankedTensorType::get({}, builder.getI1Type()); auto i64TensorType = RankedTensorType::get({}, builder.getI64Type()); auto pCkptsType = cast(pCkpts.getType()); @@ -1302,7 +1303,7 @@ SubtreeBuildResult impulse::buildIterativeSubtree(OpBuilder &builder, Location l updatedTree.turning = impulse::SelectOp::create(builder, loc, i1TensorType, isFirstLeaf, - newLeaf.turning, iterativeTurning); + newLeaf.turning, iterativeTurning); auto nextLeafIdx = arith::AddIOp::create(builder, loc, bodyLeafIdx, oneI64); @@ -1328,9 +1329,10 @@ SubtreeBuildResult impulse::buildIterativeSubtree(OpBuilder &builder, Location l } SubtreeBuildResult impulse::doubleTree(OpBuilder &builder, Location loc, - const NUTSTreeState &tree, Value direction, - Value pCkpts, Value pSumCkpts, - const NUTSContext &ctx, bool debugDump) { + const NUTSTreeState &tree, + Value direction, Value pCkpts, + Value pSumCkpts, const NUTSContext &ctx, + bool debugDump) { auto rngSplit2 = impulse::RandomSplitOp::create( builder, loc, TypeRange{tree.rng.getType(), tree.rng.getType()}, tree.rng); @@ -1351,8 +1353,8 @@ SubtreeBuildResult impulse::doubleTree(OpBuilder &builder, Location loc, } NUTSTreeState impulse::buildTree(OpBuilder &builder, Location loc, - const NUTSTreeState &initialTree, - const NUTSContext &ctx, bool debugDump) { + const NUTSTreeState 
&initialTree, + const NUTSContext &ctx, bool debugDump) { auto elemType = cast(ctx.stepSize.getType()).getElementType(); auto F64TensorType = RankedTensorType::get({}, elemType); @@ -1441,7 +1443,7 @@ NUTSTreeState impulse::buildTree(OpBuilder &builder, Location loc, builder, loc, TypeRange{rngDir.getType(), F64TensorType}, rngDir, zeroConst, oneConst, impulse::RngDistributionAttr::get(builder.getContext(), - impulse::RngDistribution::UNIFORM)); + impulse::RngDistribution::UNIFORM)); auto direction = arith::CmpFOp::create(builder, loc, arith::CmpFPredicate::OLT, directionRandom.getResult(), halfConst); @@ -1463,8 +1465,9 @@ NUTSTreeState impulse::buildTree(OpBuilder &builder, Location loc, return NUTSTreeState::fromValues(results); } -std::pair -impulse::leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx) { +std::pair impulse::leafIdxToCheckpointIdxs(OpBuilder &builder, + Location loc, + Value leafIdx) { auto i64TensorType = cast(leafIdx.getType()); auto oneConst = arith::ConstantOp::create( @@ -1500,9 +1503,9 @@ impulse::leafIdxToCheckpointIdxs(OpBuilder &builder, Location loc, Value leafIdx } Value impulse::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, - Value pSum, Value pCkpts, Value pSumCkpts, - Value idxMin, Value idxMax, - const NUTSContext &ctx, bool debugDump) { + Value pSum, Value pCkpts, Value pSumCkpts, + Value idxMin, Value idxMax, + const NUTSContext &ctx, bool debugDump) { auto positionType = ctx.getPositionType(); auto i64TensorType = RankedTensorType::get({}, builder.getI64Type()); auto i1TensorType = RankedTensorType::get({}, builder.getI1Type()); @@ -1570,9 +1573,9 @@ Value impulse::checkIterativeTurning(OpBuilder &builder, Location loc, Value p, std::pair impulse::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, - Value ckptIdxMax, Value p, Value pSum, Value pCkpts, - Value pSumCkpts, const NUTSContext &ctx, - bool debugDump) { + Value ckptIdxMax, Value p, Value pSum, Value pCkpts, + 
Value pSumCkpts, const NUTSContext &ctx, + bool debugDump) { auto i64TensorType = RankedTensorType::get({}, builder.getI64Type()); auto oneI64 = arith::ConstantOp::create( builder, loc, i64TensorType, @@ -1601,7 +1604,7 @@ impulse::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, ValueRange{ckptIdxMax, zeroI64}); impulse::YieldOp::create(builder, loc, - ValueRange{updatedPCkpts, updatedPSumCkpts}); + ValueRange{updatedPCkpts, updatedPSumCkpts}); } { Block *falseBranch = builder.createBlock(&ifOp.getFalseBranch()); @@ -1619,7 +1622,7 @@ impulse::updateCheckpoints(OpBuilder &builder, Location loc, Value leafIdx, } DualAveragingState impulse::initDualAveraging(OpBuilder &builder, Location loc, - Value stepSize) { + Value stepSize) { auto stepSizeType = cast(stepSize.getType()); auto elemType = stepSizeType.getElementType(); auto scalarType = RankedTensorType::get({}, elemType); @@ -1652,8 +1655,8 @@ DualAveragingState impulse::initDualAveraging(OpBuilder &builder, Location loc, DualAveragingState impulse::updateDualAveraging(OpBuilder &builder, Location loc, - const DualAveragingState &state, Value acceptProb, - const DualAveragingConfig &config) { + const DualAveragingState &state, Value acceptProb, + const DualAveragingConfig &config) { // Dual Averaging update: // g = target_accept_prob - accept_prob // g_avg = (1 - 1/(t+t0)) * g_avg + g/(t+t0) @@ -1744,14 +1747,14 @@ impulse::updateDualAveraging(OpBuilder &builder, Location loc, } Value impulse::getStepSizeFromDualAveraging(OpBuilder &builder, Location loc, - const DualAveragingState &state, - bool final) { + const DualAveragingState &state, + bool final) { Value logStepSize = final ? 
state.log_step_size_avg : state.log_step_size; return math::ExpOp::create(builder, loc, logStepSize); } WelfordState impulse::initWelford(OpBuilder &builder, Location loc, - int64_t positionSize, bool diagonal) { + int64_t positionSize, bool diagonal) { auto elemType = builder.getF64Type(); auto i64TensorType = RankedTensorType::get({}, builder.getI64Type()); auto meanType = RankedTensorType::get({positionSize}, elemType); @@ -1782,8 +1785,8 @@ WelfordState impulse::initWelford(OpBuilder &builder, Location loc, } WelfordState impulse::updateWelford(OpBuilder &builder, Location loc, - const WelfordState &state, Value sample, - const WelfordConfig &config) { + const WelfordState &state, Value sample, + const WelfordConfig &config) { // Algorithm: // n = n + 1 // delta_pre = sample - mean @@ -1804,8 +1807,8 @@ WelfordState impulse::updateWelford(OpBuilder &builder, Location loc, auto scalarType = RankedTensorType::get({}, elemType); Value nFloat = arith::SIToFPOp::create(builder, loc, scalarType, nNew); - Value nBroadcast = BroadcastOp::create(builder, loc, sampleType, - nFloat, sampleType.getShape()); + Value nBroadcast = BroadcastOp::create(builder, loc, sampleType, nFloat, + sampleType.getShape()); Value deltaPre = arith::SubFOp::create(builder, loc, sample, state.mean); @@ -1835,8 +1838,8 @@ WelfordState impulse::updateWelford(OpBuilder &builder, Location loc, } Value impulse::finalizeWelford(OpBuilder &builder, Location loc, - const WelfordState &state, - const WelfordConfig &config) { + const WelfordState &state, + const WelfordConfig &config) { // Compute sample covariance: cov = m2 / (n - 1) auto m2Type = cast(state.m2.getType()); auto elemType = m2Type.getElementType(); @@ -1850,8 +1853,8 @@ Value impulse::finalizeWelford(OpBuilder &builder, Location loc, Value nMinus1Float = arith::SIToFPOp::create(builder, loc, scalarType, nMinus1); - Value nMinus1Bcast = BroadcastOp::create( - builder, loc, m2Type, nMinus1Float, m2Type.getShape()); + Value nMinus1Bcast = 
BroadcastOp::create(builder, loc, m2Type, nMinus1Float, + m2Type.getShape()); Value cov = arith::DivFOp::create(builder, loc, state.m2, nMinus1Bcast); @@ -1870,8 +1873,8 @@ Value impulse::finalizeWelford(OpBuilder &builder, Location loc, Value nPlusFive = arith::AddFOp::create(builder, loc, nFloat, fiveConst); Value scale = arith::DivFOp::create(builder, loc, nFloat, nPlusFive); - Value scaleBcast = BroadcastOp::create(builder, loc, m2Type, scale, - m2Type.getShape()); + Value scaleBcast = + BroadcastOp::create(builder, loc, m2Type, scale, m2Type.getShape()); Value scaledCov = arith::MulFOp::create(builder, loc, scaleBcast, cov); auto shrinkageBaseConst = arith::ConstantOp::create( @@ -1882,13 +1885,13 @@ Value impulse::finalizeWelford(OpBuilder &builder, Location loc, arith::DivFOp::create(builder, loc, shrinkageBaseConst, nPlusFive); if (config.diagonal) { - Value shrinkageBcast = BroadcastOp::create( - builder, loc, m2Type, shrinkage, m2Type.getShape()); + Value shrinkageBcast = BroadcastOp::create(builder, loc, m2Type, + shrinkage, m2Type.getShape()); cov = arith::AddFOp::create(builder, loc, scaledCov, shrinkageBcast); } else { Value identity = createIdentityMatrix(builder, loc, m2Type); - Value shrinkageBcast = BroadcastOp::create( - builder, loc, m2Type, shrinkage, m2Type.getShape()); + Value shrinkageBcast = BroadcastOp::create(builder, loc, m2Type, + shrinkage, m2Type.getShape()); Value shrinkageI = arith::MulFOp::create(builder, loc, shrinkageBcast, identity); cov = arith::AddFOp::create(builder, loc, scaledCov, shrinkageI); @@ -1950,8 +1953,8 @@ SmallVector impulse::buildAdaptationSchedule(int64_t numSteps) { } Value impulse::unconstrainPosition(OpBuilder &builder, Location loc, - Value constrained, - ArrayRef supports) { + Value constrained, + ArrayRef supports) { bool hasConstraints = false; for (const auto &info : supports) { if (info.support && info.support.getKind() != impulse::SupportKind::REAL) { @@ -2039,8 +2042,8 @@ Value 
impulse::unconstrainPosition(OpBuilder &builder, Location loc, } Value impulse::constrainPosition(OpBuilder &builder, Location loc, - Value unconstrained, - ArrayRef supports) { + Value unconstrained, + ArrayRef supports) { bool hasConstraints = false; for (const auto &info : supports) { if (info.support && info.support.getKind() != impulse::SupportKind::REAL) { @@ -2129,8 +2132,8 @@ Value impulse::constrainPosition(OpBuilder &builder, Location loc, } Value impulse::computeTotalJacobianCorrection(OpBuilder &builder, Location loc, - Value unconstrained, - ArrayRef supports) { + Value unconstrained, + ArrayRef supports) { auto inputType = cast(unconstrained.getType()); auto elemType = inputType.getElementType(); auto scalarType = RankedTensorType::get({}, elemType); diff --git a/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h b/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h index e07257126139..a44ff1a252a4 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/TransformUtils.h @@ -24,10 +24,12 @@ namespace enzyme { namespace transforms { /// Get the unconstrained size given a constrained size and support kind. -int64_t getUnconstrainedSize(int64_t constrainedSize, impulse::SupportKind kind); +int64_t getUnconstrainedSize(int64_t constrainedSize, + impulse::SupportKind kind); /// Get the constrained size given an unconstrained size and support kind. -int64_t getConstrainedSize(int64_t unconstrainedSize, impulse::SupportKind kind); +int64_t getConstrainedSize(int64_t unconstrainedSize, + impulse::SupportKind kind); /// Transform from constrained to unconstrained space. 
Value unconstrain(OpBuilder &builder, Location loc, Value constrained, From fc7749d5331361713e044279ef92f419b7ab982e Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 16:54:58 -0500 Subject: [PATCH 3/5] undo enzyme.dump move --- enzyme/BUILD | 12 ---------- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 18 +++++++++++++++ .../Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td | 22 ------------------- .../CoreDialectsAutoDiffImplementations.cpp | 1 - .../CoreDialectsAutoDiffImplementations.h | 1 - .../MLIR/Implementations/EnzymeDerivatives.td | 1 + .../ImpulseAutoDiffOpInterfaceImpl.cpp | 19 ---------------- .../Implementations/ImpulseDerivatives.td | 3 --- enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp | 4 ++-- enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h | 2 +- 10 files changed, 22 insertions(+), 61 deletions(-) delete mode 100644 enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp delete mode 100644 enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td diff --git a/enzyme/BUILD b/enzyme/BUILD index 2384ff2ef2ee..3baa8e5719cb 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -943,17 +943,6 @@ gentbl_cc_library( deps = [":ImplementationsCommonTdFiles"], ) -gentbl_cc_library( - name = "impulse-derivatives", - strip_include_prefix = "Enzyme/MLIR", - tbl_outs = [( - ["-gen-mlir-derivatives"], - "Enzyme/MLIR/Implementations/ImpulseDerivatives.inc", - )], - tblgen = ":enzyme-tblgen", - td_file = "Enzyme/MLIR/Implementations/ImpulseDerivatives.td", - deps = [":ImplementationsCommonTdFiles"], -) cc_library( name = "EnzymeMLIR", @@ -1000,7 +989,6 @@ cc_library( ":cf-derivatives", ":complex-derivatives", ":enzyme-derivatives", - ":impulse-derivatives", ":func-derivatives", ":linalg-derivatives", ":llvm-derivatives", diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index f28396026faf..88ea2f03ba5c 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -415,6 +415,24 @@ def 
ExtractOp : Enzyme_Op<"extract"> { }]; } +def DumpOp : Enzyme_Op<"dump"> { + let summary = "Debug operation to dump a tensor value at runtime"; + let description = [{ + Debug operation that dumps a tensor value with a label. + }]; + + let arguments = (ins + AnyType:$value, + StrAttr:$label + ); + + let results = (outs AnyType:$output); + + let assemblyFormat = [{ + $value attr-dict `:` functional-type($value, results) + }]; +} + def AffineAtomicRMWOp : Enzyme_Op<"affine_atomic_rmw"> { let summary = "affine atomic rmw operation"; let description = [{ diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td index 9d9944f4124e..bc898d6167e0 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td @@ -630,26 +630,4 @@ def PopcountOp : Impulse_Op<"popcount", [Pure, SameOperandsAndResultElementType, }]; } -//===----------------------------------------------------------------------===// -// Debug ops -//===----------------------------------------------------------------------===// - -def DumpOp : Impulse_Op<"dump"> { - let summary = "Debug operation to dump a tensor value at runtime"; - let description = [{ - Debug operation that dumps a tensor value with a label. 
- }]; - - let arguments = (ins - AnyType:$value, - StrAttr:$label - ); - - let results = (outs AnyType:$output); - - let assemblyFormat = [{ - $value attr-dict `:` functional-type($value, results) - }]; -} - #endif // IMPULSE_OPS diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index d39fd3da387b..f88d212e8c76 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -481,5 +481,4 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerFuncDialectAutoDiffInterface(registry); enzyme::registerTensorDialectAutoDiffInterface(registry); enzyme::registerEnzymeDialectAutoDiffInterface(registry); - enzyme::registerImpulseDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index f01b3369a234..dabce9ecd520 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -263,7 +263,6 @@ void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry); void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry); void registerEnzymeDialectAutoDiffInterface(DialectRegistry ®istry); -void registerImpulseDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td index 897e58ff4dfc..32098bce5aab 100644 --- a/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/EnzymeDerivatives.td @@ -1,3 
+1,4 @@ include "Common.td" def : InactiveOp<"enzyme", "IgnoreDerivativesOp">; +def : InactiveOp<"enzyme", "DumpOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp deleted file mode 100644 index 1509f3b84b54..000000000000 --- a/enzyme/Enzyme/MLIR/Implementations/ImpulseAutoDiffOpInterfaceImpl.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "Implementations/CoreDialectsAutoDiffImplementations.h" - -#include "Dialect/Impulse/Impulse.h" -#include "mlir/IR/TypeSupport.h" - -using namespace mlir; -using namespace mlir::enzyme; -using namespace mlir::impulse; - -namespace { -#include "Implementations/ImpulseDerivatives.inc" -} // namespace - -void mlir::enzyme::registerImpulseDialectAutoDiffInterface( - DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *context, impulse::ImpulseDialect *) { - registerInterfaces(context); - }); -} diff --git a/enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td deleted file mode 100644 index 1e04268b9cd7..000000000000 --- a/enzyme/Enzyme/MLIR/Implementations/ImpulseDerivatives.td +++ /dev/null @@ -1,3 +0,0 @@ -include "Common.td" - -def : InactiveOp<"impulse", "DumpOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp index 0a30eb5cd454..c13a94c219e4 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.cpp @@ -64,8 +64,8 @@ SmallVector NUTSTreeState::getTypes() const { Value impulse::conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump) { if (debugDump) { - return impulse::DumpOp::create(builder, loc, value.getType(), value, - builder.getStringAttr(label)) + return enzyme::DumpOp::create(builder, loc, value.getType(), value, + builder.getStringAttr(label)) .getOutput(); } return value; diff --git 
a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h index d79d0a9118cc..2c67f4a9bf9c 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/HMCUtils.h @@ -227,7 +227,7 @@ struct SubtreeBuildResult { }; /// Conditionally dump a value for debugging. -/// Emits an impulse::DumpOp if `debugDump` is true; otherwise has no effect. +/// Emits an enzyme::DumpOp if `debugDump` is true; otherwise has no effect. Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label, bool debugDump); From acf3a217328eaf31e0b038f4deee6d786791e454 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 18:58:48 -0500 Subject: [PATCH 4/5] fix build --- enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt b/enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt index 728df4d51c01..33d32f35062d 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_dialect(ImpulseOps impulse) -add_mlir_doc(ImpulseDialect -gen-dialect-doc ImpulseDialect Enzyme/) -add_mlir_doc(ImpulseOps -gen-op-doc ImpulseOps Enzyme/) +add_mlir_doc(Dialect ImpulseDialect Impulse/ -gen-dialect-doc) +add_mlir_doc(ImpulseOps ImpulseOps Impulse/ -gen-op-doc) set(LLVM_TARGET_DEFINITIONS ImpulseOps.td) mlir_tablegen(ImpulseAttributeInterfaces.h.inc -gen-attr-interface-decls) From 6f759e3ecc490f54468a4f4a93b1bf23f19a57cc Mon Sep 17 00:00:00 2001 From: sbrantq Date: Mon, 30 Mar 2026 19:42:03 -0500 Subject: [PATCH 5/5] fix up build --- enzyme/BUILD | 2 +- enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td | 2 +- enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index 3baa8e5719cb..dada5822a704 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ 
-447,7 +447,7 @@ td_library( "Enzyme/MLIR/Dialect/Impulse/Dialect.td", "Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td", ], - includes = [".", "Enzyme/MLIR/Dialect"], + includes = ["."], deps = [ ":EnzymeDialectTdFiles", "@llvm-project//mlir:OpBaseTdFiles", diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td index 565ef66c8b4f..3dcb29ec5376 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td @@ -10,7 +10,7 @@ #define IMPULSE_ENUMS include "mlir/IR/EnumAttr.td" -include "Impulse/Dialect.td" +include "Dialect.td" class Impulse_Attr traits = []> : AttrDef { diff --git a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td index bc898d6167e0..2e935fed4034 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td @@ -9,8 +9,8 @@ #ifndef IMPULSE_OPS #define IMPULSE_OPS -include "Impulse/ImpulseEnums.td" -include "Impulse/Dialect.td" +include "ImpulseEnums.td" +include "Dialect.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/IR/EnumAttr.td"