diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml index 2bbf2383..587fb8d4 100644 --- a/.github/workflows/benchmark_pr.yml +++ b/.github/workflows/benchmark_pr.yml @@ -1,9 +1,6 @@ name: Benchmark pull request -on: - pull_request: - branches: - - master +on: [pull_request] permissions: pull-requests: write diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2356955c..125f32f3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,8 +1,6 @@ name: CI on: pull_request: - branches: - - master push: branches: - master diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index f645e44a..3166ab9d 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -223,38 +223,20 @@ Returns the canonical e-class id for a given e-class. @inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))] -# function canonicalize(g::EGraph, n::VecExpr)::VecExpr -# if !v_isexpr(n) -# v_hash!(n) -# return n -# end -# l = v_arity(n) -# new_n = v_new(l) -# v_set_flag!(new_n, v_flags(n)) -# v_set_head!(new_n, v_head(n)) -# for i in v_children_range(n) -# @inbounds new_n[i] = find(g, n[i]) -# end -# v_hash!(new_n) -# new_n -# end - function canonicalize!(g::EGraph, n::VecExpr) - v_isexpr(n) || @goto ret - for i in (VECEXPR_META_LENGTH + 1):length(n) - @inbounds n[i] = find(g, n[i]) + if v_isexpr(n) + for i in (VECEXPR_META_LENGTH + 1):length(n) + @inbounds n[i] = find(g, n[i]) + end end - v_unset_hash!(n) - @label ret - v_hash!(n) n end function lookup(g::EGraph, n::VecExpr)::Id canonicalize!(g, n) - h = IdKey(v_hash(n)) - haskey(g.memo, n) ? find(g, g.memo[n]) : 0 + id = get(g.memo, n, zero(Id)) + iszero(id) ? id : find(g, id) end @@ -272,9 +254,10 @@ Inserts an e-node in an [`EGraph`](@ref) """ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis} canonicalize!(g, n) - - haskey(g.memo, n) && return g.memo[n] - + + id = get(g.memo, n, zero(Id)) + iszero(id) || return id + if should_copy n = copy(n) end @@ -319,28 +302,21 @@ function addexpr!(g::EGraph, se)::Id se isa EClass && return se.id e = preprocess(se) - n = if isexpr(e) - args = iscall(e) ? arguments(e) : children(e) - ar = length(args) - n = v_new(ar) - v_set_flag!(n, VECEXPR_FLAG_ISTREE) - iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL) - - h = iscall(e) ? operation(e) : head(e) - v_set_head!(n, add_constant!(g, h)) - - # get the signature from op and arity - v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar))) - - for i in v_children_range(n) - @inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH]) - end - n - else # constant enode - VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]) + isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), add_constant!(g, e)]), false) # constant enode + + args = iscall(e) ? arguments(e) : children(e) + ar = length(args) + n = v_new(ar) + v_set_flag!(n, VECEXPR_FLAG_ISTREE) + iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL) + h = iscall(e) ? operation(e) : head(e) + v_set_head!(n, add_constant!(g, h)) + # get the signature from op and arity + v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar))) + for i in v_children_range(n) + @inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH]) end - id = add!(g, n, false) - return id + add!(g, n, false) end """ @@ -431,9 +407,8 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp while !isempty(g.pending) (node::VecExpr, eclass_id::Id) = pop!(g.pending) canonicalize!(g, node) - if haskey(g.memo, node) - old_class_id = g.memo[node] - g.memo[node] = eclass_id + old_class_id = get!(g.memo, node, eclass_id) + if old_class_id != eclass_id did_something = union!(g, old_class_id, eclass_id) # TODO unique! can node dedup be moved here? compare performance # did_something && unique!(g[eclass_id].nodes) @@ -473,9 +448,8 @@ function check_memo(g::EGraph)::Bool for (id, class) in g.classes @assert id.val == class.id for node in class.nodes - if haskey(test_memo, node) - old_id = test_memo[node] - test_memo[node] = id.val + old_id = get!(test_memo, node, id.val) + if old_id != id.val @assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)" end end @@ -483,7 +457,7 @@ function check_memo(g::EGraph)::Bool for (node, id) in test_memo @assert id == find(g, id) - @assert id == find(g, g.memo[node]) + @assert id == find(g, g.memo[node]) "Entry for $node at $id in test_memo was incorrect." end true diff --git a/src/Patterns.jl b/src/Patterns.jl index cf0a9ee4..2568f176 100644 --- a/src/Patterns.jl +++ b/src/Patterns.jl @@ -26,7 +26,7 @@ isground(p::AbstractPat) = false struct PatLiteral <: AbstractPat value n::VecExpr - PatLiteral(val) = new(val, VecExpr(Id[0, 0, 0, hash(val)])) + PatLiteral(val) = new(val, VecExpr(Id[0, 0, hash(val)])) end PatLiteral(p::AbstractPat) = throw(DomainError(p, "Cannot construct a pattern literal of another pattern object.")) diff --git a/src/vecexpr.jl b/src/vecexpr.jl index 1e8a83ba..53183bcf 100644 --- a/src/vecexpr.jl +++ b/src/vecexpr.jl @@ -34,14 +34,12 @@ const Id = UInt64 end An e-node is represented by `Vector{Id}` where: -* Position 1 stores the hash of the `VecExpr`. -* Position 2 stores the bit flags (`isexpr` or `iscall`). -* Position 3 stores the signature -* Position 4 stores the hash of the `head` (if `isexpr`) or node value in the e-graph constants. +* Position 1 stores the bit flags (`isexpr` or `iscall`). +* Position 2 stores the signature +* Position 3 stores the hash of the `head` (if `isexpr`) or node value in the e-graph constants. * The rest of the positions store the e-class ids of the children nodes. The expression is represented as an array of integers to improve performance. -The hash value for the VecExpr is cached in the first position for faster lookup performance in dictionaries. """ struct VecExpr data::Vector{Id} @@ -49,12 +47,12 @@ end const VECEXPR_FLAG_ISTREE = 0x01 const VECEXPR_FLAG_ISCALL = 0x10 -const VECEXPR_META_LENGTH = 4 +const VECEXPR_META_LENGTH = 3 -@inline v_flags(n::VecExpr)::Id = @inbounds n.data[2] -@inline v_unset_flags!(n::VecExpr) = @inbounds (n.data[2] = 0) +@inline v_flags(n::VecExpr)::Id = @inbounds n.data[1] +@inline v_unset_flags!(n::VecExpr) = @inbounds (n.data[1] = 0) @inline v_check_flags(n::VecExpr, flag::Id)::Bool = !iszero(v_flags(n) & flags) -@inline v_set_flag!(n::VecExpr, flag)::Id = @inbounds (n.data[2] = n.data[2] | flag) +@inline v_set_flag!(n::VecExpr, flag)::Id = @inbounds (n.data[1] = n.data[1] | flag) """Returns `true` if the e-node ID points to a an expression tree.""" @inline v_isexpr(n::VecExpr)::Bool = !iszero(v_flags(n) & VECEXPR_FLAG_ISTREE) @@ -65,33 +63,15 @@ const VECEXPR_META_LENGTH = 4 """Number of children in the e-node.""" @inline v_arity(n::VecExpr)::Int = length(n.data) - VECEXPR_META_LENGTH -""" -Compute the hash of a `VecExpr` and store it as the first element. -""" -@inline function v_hash!(n::VecExpr)::Id - if iszero(n.data[1]) - n.data[1] = hash(@view n.data[2:end]) - else - # h = hash(@view n[2:end]) - # @assert h == n[1] - n.data[1] - end -end - -"""The hash of the e-node.""" -@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1] -Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here -Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end]) - -"""Set e-node hash to zero.""" -@inline v_unset_hash!(n::VecExpr)::Id = @inbounds (n.data[1] = Id(0)) +Base.hash(n::VecExpr) = hash(n.data) +Base.:(==)(a::VecExpr, b::VecExpr) = a.data == b.data """E-class IDs of the children of the e-node.""" @inline v_children(n::VecExpr) = @view n.data[(VECEXPR_META_LENGTH + 1):end] -@inline v_signature(n::VecExpr)::Id = @inbounds n.data[3] +@inline v_signature(n::VecExpr)::Id = @inbounds n.data[2] -@inline v_set_signature!(n::VecExpr, sig::Id) = @inbounds (n.data[3] = sig) +@inline v_set_signature!(n::VecExpr, sig::Id) = @inbounds (n.data[2] = sig) "The constant ID of the operation of the e-node, or the e-node ." @inline v_head(n::VecExpr)::Id = @inbounds n.data[VECEXPR_META_LENGTH] @@ -102,7 +82,6 @@ Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:en """Construct a new, empty `VecExpr` with `len` children.""" @inline function v_new(len::Int)::VecExpr n = VecExpr(Vector{Id}(undef, len + VECEXPR_META_LENGTH)) - v_unset_hash!(n) v_unset_flags!(n) n end