diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 0a86d91f..5937ddad 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -224,22 +224,6 @@ Returns the canonical e-class id for a given e-class. @inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))] -# function canonicalize(g::EGraph, n::VecExpr)::VecExpr -# if !v_isexpr(n) -# v_hash!(n) -# return n -# end -# l = v_arity(n) -# new_n = v_new(l) -# v_set_flag!(new_n, v_flags(n)) -# v_set_head!(new_n, v_head(n)) -# for i in v_children_range(n) -# @inbounds new_n[i] = find(g, n[i]) -# end -# v_hash!(new_n) -# new_n -# end - function canonicalize!(g::EGraph, n::VecExpr) v_isexpr(n) || @goto ret for i in (VECEXPR_META_LENGTH + 1):length(n) @@ -253,19 +237,16 @@ end function lookup(g::EGraph, n::VecExpr)::Id canonicalize!(g, n) - h = IdKey(v_hash(n)) - haskey(g.memo, n) ? find(g, g.memo[n]) : 0 + id = get(g.memo, n, zero(Id)) + iszero(id) ? id : find(g, id) end function add_class_by_op(g::EGraph, n, eclass_id) key = IdKey(v_signature(n)) - if haskey(g.classes_by_op, key) - push!(g.classes_by_op[key], eclass_id) - else - g.classes_by_op[key] = [eclass_id] - end + vec = get!(g.classes_by_op, key, Vector{Id}()) + push!(vec, eclass_id) end """ @@ -274,7 +255,8 @@ Inserts an e-node in an [`EGraph`](@ref) function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis} canonicalize!(g, n) - haskey(g.memo, n) && return g.memo[n] + id = get(g.memo, n, zero(Id)) + iszero(id) || return id if should_copy n = copy(n) @@ -291,7 +273,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool) g.memo[n] = id add_class_by_op(g, n, id) - eclass = EClass{Analysis}(id, VecExpr[n], Pair{VecExpr,Id}[], make(g, n)) + eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n)) g.classes[IdKey(id)] = eclass modify!(g, eclass) push!(g.pending, n => id) @@ -320,28 +302,22 @@ function addexpr!(g::EGraph, se)::Id se isa EClass && return se.id e = preprocess(se) - n = if isexpr(e) - args = iscall(e) ? arguments(e) : children(e) - ar = length(args) - n = v_new(ar) - v_set_flag!(n, VECEXPR_FLAG_ISTREE) - iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL) - - h = iscall(e) ? operation(e) : head(e) - v_set_head!(n, add_constant!(g, h)) - - # get the signature from op and arity - v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar))) - - for i in v_children_range(n) - @inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH]) - end - n - else # constant enode - VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]) + isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false) + + args = iscall(e) ? arguments(e) : children(e) + ar = length(args) + n = v_new(ar) + v_set_flag!(n, VECEXPR_FLAG_ISTREE) + iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL) + h = iscall(e) ? operation(e) : head(e) + v_set_head!(n, add_constant!(g, h)) + # get the signature from op and arity + v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar))) + for i in v_children_range(n) + @inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH]) end - id = add!(g, n, false) - return id + + add!(g, n, false) end """ @@ -431,10 +407,10 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp while !isempty(g.pending) || !isempty(g.analysis_pending) while !isempty(g.pending) (node::VecExpr, eclass_id::Id) = pop!(g.pending) + node = copy(node) canonicalize!(g, node) - if haskey(g.memo, node) - old_class_id = g.memo[node] - g.memo[node] = eclass_id + old_class_id = get!(g.memo, node, eclass_id) + if old_class_id != eclass_id did_something = union!(g, old_class_id, eclass_id) # TODO unique! can node dedup be moved here? compare performance # did_something && unique!(g[eclass_id].nodes) @@ -474,9 +450,8 @@ function check_memo(g::EGraph)::Bool for (id, class) in g.classes @assert id.val == class.id for node in class.nodes - if haskey(test_memo, node) - old_id = test_memo[node] - test_memo[node] = id.val + old_id = get!(test_memo, node, id.val) + if old_id != id.val @assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)" end end diff --git a/src/EGraphs/uniquequeue.jl b/src/EGraphs/uniquequeue.jl index aade15d6..512bb61a 100644 --- a/src/EGraphs/uniquequeue.jl +++ b/src/EGraphs/uniquequeue.jl @@ -30,4 +30,4 @@ function Base.pop!(uq::UniqueQueue{T}) where {T} v end -Base.isempty(uq::UniqueQueue) = isempty(uq.vec) \ No newline at end of file +Base.isempty(uq::UniqueQueue) = isempty(uq.vec) diff --git a/src/vecexpr.jl b/src/vecexpr.jl index 1e8a83ba..c18059f9 100644 --- a/src/vecexpr.jl +++ b/src/vecexpr.jl @@ -80,7 +80,7 @@ end """The hash of the e-node.""" @inline v_hash(n::VecExpr)::Id = @inbounds n.data[1] -Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here +Base.hash(n::VecExpr, h::UInt) = hash(v_hash(n), h) # IdKey not necessary here Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end]) """Set e-node hash to zero."""