Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 30 additions & 52 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,48 +224,32 @@ Returns the canonical e-class id for a given e-class.

@inline Base.getindex(g::EGraph, i::Id) = g.classes[IdKey(find(g, i))]

# function canonicalize(g::EGraph, n::VecExpr)::VecExpr
# if !v_isexpr(n)
# v_hash!(n)
# return n
# end
# l = v_arity(n)
# new_n = v_new(l)
# v_set_flag!(new_n, v_flags(n))
# v_set_head!(new_n, v_head(n))
# for i in v_children_range(n)
# @inbounds new_n[i] = find(g, n[i])
# end
# v_hash!(new_n)
# new_n
# end

function canonicalize!(g::EGraph, n::VecExpr)
# orig = copy(n)
# inmemo = any(entry -> objectid(entry) == objectid(n), keys(g.memo))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gkronber leftover

v_isexpr(n) || @goto ret
for i in (VECEXPR_META_LENGTH + 1):length(n)
@inbounds n[i] = find(g, n[i])
end
v_unset_hash!(n)
@label ret
v_hash!(n)
# @assert orig == n || !inmemo
n
end

function lookup(g::EGraph, n::VecExpr)::Id
canonicalize!(g, n)
h = IdKey(v_hash(n))

haskey(g.memo, n) ? find(g, g.memo[n]) : 0
id = get(g.memo, n, zero(Id))
iszero(id) ? id : find(g, id)
end


function add_class_by_op(g::EGraph, n, eclass_id)
key = IdKey(v_signature(n))
if haskey(g.classes_by_op, key)
push!(g.classes_by_op[key], eclass_id)
else
g.classes_by_op[key] = [eclass_id]
end
vec = get!(g.classes_by_op, key, Vector{Id}())
push!(vec, eclass_id)
end

"""
Expand All @@ -274,7 +258,8 @@ Inserts an e-node in an [`EGraph`](@ref)
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
canonicalize!(g, n)

haskey(g.memo, n) && return g.memo[n]
id = get(g.memo, n, zero(Id))
iszero(id) || return id

if should_copy
n = copy(n)
Expand All @@ -291,7 +276,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
g.memo[n] = id

add_class_by_op(g, n, id)
eclass = EClass{Analysis}(id, VecExpr[n], Pair{VecExpr,Id}[], make(g, n))
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n))
g.classes[IdKey(id)] = eclass
modify!(g, eclass)
push!(g.pending, n => id)
Expand Down Expand Up @@ -320,28 +305,22 @@ function addexpr!(g::EGraph, se)::Id
se isa EClass && return se.id
e = preprocess(se)

n = if isexpr(e)
args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)

h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))

# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))

for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
n
else # constant enode
VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)])
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false)

args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
n = v_new(ar)
v_set_flag!(n, VECEXPR_FLAG_ISTREE)
iscall(e) && v_set_flag!(n, VECEXPR_FLAG_ISCALL)
h = iscall(e) ? operation(e) : head(e)
v_set_head!(n, add_constant!(g, h))
# get the signature from op and arity
v_set_signature!(n, hash(maybe_quote_operation(h), hash(ar)))
for i in v_children_range(n)
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end
id = add!(g, n, false)
return id

add!(g, n, false)
end

"""
Expand Down Expand Up @@ -431,10 +410,10 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
while !isempty(g.pending) || !isempty(g.analysis_pending)
while !isempty(g.pending)
(node::VecExpr, eclass_id::Id) = pop!(g.pending)
node = copy(node)
canonicalize!(g, node)
if haskey(g.memo, node)
old_class_id = g.memo[node]
g.memo[node] = eclass_id
old_class_id = get!(g.memo, node, eclass_id)
if old_class_id != eclass_id
did_something = union!(g, old_class_id, eclass_id)
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
Expand Down Expand Up @@ -474,9 +453,8 @@ function check_memo(g::EGraph)::Bool
for (id, class) in g.classes
@assert id.val == class.id
for node in class.nodes
if haskey(test_memo, node)
old_id = test_memo[node]
test_memo[node] = id.val
old_id = get!(test_memo, node, id.val)
if old_id != id.val
@assert find(g, old_id) == find(g, id.val) "Unexpected equivalence $node $(g[find(g, id.val)].nodes) $(g[find(g, old_id)].nodes)"
end
end
Expand Down
13 changes: 11 additions & 2 deletions src/EGraphs/uniquequeue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@ end
UniqueQueue{T}() where {T} = UniqueQueue{T}(Set{T}(), T[])

function Base.push!(uq::UniqueQueue{T}, x::T) where {T}
if !(x in uq.set)
push!(uq.set, x)
# checks if x is contained in s and adds x if it is not, using a single hash call and lookup
# available from Julia 1.11
function in!(x::T, s::Set)
idx, sh = Base.ht_keyindex2_shorthash!(s.dict, x)
idx > 0 && return true
Base._setindex!(s.dict, nothing, x, -idx, sh)

false
end

if !in!(x, uq.set)
push!(uq.vec, x)
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/vecexpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end

"""The hash of the e-node."""
@inline v_hash(n::VecExpr)::Id = @inbounds n.data[1]
Base.hash(n::VecExpr) = v_hash(n) # IdKey not necessary here
Base.hash(n::VecExpr, h::UInt) = hash(v_hash(n), h) # IdKey not necessary here
Base.:(==)(a::VecExpr, b::VecExpr) = (@view a.data[2:end]) == (@view b.data[2:end])

"""Set e-node hash to zero."""
Expand Down