From d4dbfab4c1a3790ab9d29346dfc3fc8d5df9fa00 Mon Sep 17 00:00:00 2001 From: jverzani Date: Fri, 2 Aug 2024 21:26:43 -0400 Subject: [PATCH 1/9] setup for new Metatheory --- Project.toml | 6 ++++-- docs/src/index.md | 3 +-- test/extension_tests.jl | 14 +++----------- test/runtests.jl | 2 +- 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index 1c3535b..2947e1d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleExpressions" uuid = "deba94f7-f32a-40ad-b45e-be020a5ded2f" authors = ["jverzani and contributors"] -version = "1.0.20" +version = "1.0.21" [deps] @@ -21,6 +21,7 @@ SimpleExpressionsTermInterfaceExt = "TermInterface" [compat] AbstractTrees = "0.4" +Metatheory = "3" RecipesBase = "1" Roots = "2" SpecialFunctions = "1,2" @@ -29,9 +30,10 @@ julia = "1.9" [extras] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test","Metatheory"] diff --git a/docs/src/index.md b/docs/src/index.md index 1e5e419..303140c 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -128,8 +128,7 @@ u = D(exp(x) * (sin(3x) + sin(101x))) No simplification is done so the expressions can quickly become unwieldy. There is an extension for `TermInterface` so rewriting of expressions, as is possible with the `Metatheory.jl` package is possible. For example, this pattern can factor out `exp(x)`: -``` -# @example expressions waiting on new Metatheory release +```@example expressions using Metatheory r = @rule (~x * ~a + ~x * ~b --> ~x * (~a + ~b)) r(u) diff --git a/test/extension_tests.jl b/test/extension_tests.jl index 33dd699..e83d019 100644 --- a/test/extension_tests.jl +++ b/test/extension_tests.jl @@ -3,19 +3,9 @@ using SimpleExpressions @symbolic x p using Metatheory -#= -need the following -[compat] -Metatheory = "3" - -[extras] -Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" - -[targets] -test = ["Test","Metatheory"] -=# +@testset "Metatheory" begin r = @rule sin(2(~x)) --> 2sin(~x)*cos(~x) @test r(sin(2x)) === 2*sin(x) * cos(x) @@ -27,3 +17,5 @@ end @test isequal(rewrite(x + x, t), 2x) @test isequal(rewrite(x/x, t), 1) @test isequal(rewrite(1*x, t), x) + +end diff --git a/test/runtests.jl b/test/runtests.jl index 0c4c110..fba83a1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,4 +4,4 @@ using Test import SimpleExpressions: @symbolic_expression include("basic_tests.jl") -#include("extension_tests.jl") +include("extension_tests.jl") From df82e55f59d657cf7569c169724ec7579bdb9dba Mon Sep 17 00:00:00 2001 From: jverzani Date: Fri, 2 Aug 2024 21:33:07 -0400 Subject: [PATCH 2/9] waiting for release --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2947e1d..121f9f8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SimpleExpressions" uuid = "deba94f7-f32a-40ad-b45e-be020a5ded2f" authors = ["jverzani and contributors"] -version = "1.0.21" +version = "1.0.22" [deps] From 9736d1c17a37dcbfdb54216e0b19ac07f27a651d Mon Sep 17 00:00:00 2001 From: jverzani Date: Sun, 4 Aug 2024 16:59:51 -0400 Subject: [PATCH 3/9] add metatheory simplify; add latexify recipe extension --- Project.toml | 5 +++ ext/SimpleExpressionsLatexifyExt.jl | 10 +++++ ext/SimpleExpressionsMetatheoryExt.jl | 51 ++++++++++++++++++++++++ ext/SimpleExpressionsTermInterfaceExt.jl | 12 +++++- src/SimpleExpressions.jl | 45 ++++++++++++++++++++- 5 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 ext/SimpleExpressionsLatexifyExt.jl create mode 100644 ext/SimpleExpressionsMetatheoryExt.jl diff --git a/Project.toml b/Project.toml index 121f9f8..4461441 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,8 @@ version = "1.0.22" [weakdeps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" +Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -14,6 +16,8 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [extensions] SimpleExpressionsAbstractTreesExt = "AbstractTrees" +SimpleExpressionsLatexifyExt = "Latexify" +SimpleExpressionsMetatheoryExt = "Metatheory" SimpleExpressionsRecipesBaseExt = "RecipesBase" SimpleExpressionsRootsExt = "Roots" SimpleExpressionsSpecialFunctionsExt = "SpecialFunctions" @@ -21,6 +25,7 @@ SimpleExpressionsTermInterfaceExt = "TermInterface" [compat] AbstractTrees = "0.4" +Latexify = "0.16, 1" Metatheory = "3" RecipesBase = "1" Roots = "2" diff --git a/ext/SimpleExpressionsLatexifyExt.jl b/ext/SimpleExpressionsLatexifyExt.jl new file mode 100644 index 0000000..b72b001 --- /dev/null +++ b/ext/SimpleExpressionsLatexifyExt.jl @@ -0,0 +1,10 @@ +module SimpleExpressionsLatexifyExt + +import SimpleExpressions +import Latexify + +Latexify.@latexrecipe function f(x::SimpleExpressions.AbstractSymbolic) + return string(x) +end + +end diff --git a/ext/SimpleExpressionsMetatheoryExt.jl b/ext/SimpleExpressionsMetatheoryExt.jl new file mode 100644 index 0000000..d043680 --- /dev/null +++ b/ext/SimpleExpressionsMetatheoryExt.jl @@ -0,0 +1,51 @@ +module SimpleExpressionsMetatheoryExt + +import SimpleExpressions +import SimpleExpressions: SymbolicNumber + +using Metatheory +using Metatheory.Library + +function SimpleExpressions.simplify(ex::SimpleExpressions.SymbolicExpression) + + mult_t = @commutative_monoid (*) 1 + plus_t = @commutative_monoid (+) 0 + + add_t = @theory a n m begin + a + a == 2a + n::Number * a + a == (n+1) * a + a + m::Number*a == (1 + m) * a + n::Number * a + m::Number * a == (n + m) * a + end + + minus_t = @theory a b begin + a - a --> 0 + a + (-b) == a - b + end + + mulplus_t = @theory a b c begin + 0 * a --> 0 + a * 0 --> 0 + a * (b + c) == ((a * b) + (a * c)) + a + (b * a) == ((b + 1) * a) + end + + pow_t = @theory x y z n m p q begin + (y^n) * y == y^(n + 1) + x^n * x^m == x^(n + m) + (x * y)^z == x^z * y^z + (x^p)^q == x^(p * q) + x^0 --> 1 + 0^x --> 0 + 1^x --> 1 + x^1 --> x + inv(x) == x^(-1) + end + maths_theory = mult_t ∪ plus_t ∪ add_t ∪ minus_t ∪ mulplus_t ∪ pow_t + + g = EGraph(ex) + saturate!(g, maths_theory) + extract!(g, astsize) +end + +end diff --git a/ext/SimpleExpressionsTermInterfaceExt.jl b/ext/SimpleExpressionsTermInterfaceExt.jl index c08e6d3..06a45fe 100644 --- a/ext/SimpleExpressionsTermInterfaceExt.jl +++ b/ext/SimpleExpressionsTermInterfaceExt.jl @@ -2,7 +2,9 @@ module SimpleExpressionsTermInterfaceExt using SimpleExpressions -import SimpleExpressions: AbstractSymbolic, Symbolic, SymbolicParameter, SymbolicExpression, SymbolicEquation +import SimpleExpressions: AbstractSymbolic, + Symbolic, SymbolicParameter, SymbolicNumber, + SymbolicExpression, SymbolicEquation using TermInterface @@ -21,12 +23,18 @@ TermInterface.iscall(ex::AbstractSymbolic) = false TermInterface.isexpr(::Symbolic) = false TermInterface.isexpr(::SymbolicParameter) = false +TermInterface.isexpr(::SymbolicNumber) = false TermInterface.isexpr(::AbstractSymbolic) = true function TermInterface.maketerm(T::Type{<:AbstractSymbolic}, head, children, metadata) - head(children...) + if isa(head, Symbol) + head == :. && return first(children) + @show head, children, metadata + return 42 + end + head(SimpleExpressions.assymbolic.(children)...) end diff --git a/src/SimpleExpressions.jl b/src/SimpleExpressions.jl index dfba4b1..6cc5f9e 100644 --- a/src/SimpleExpressions.jl +++ b/src/SimpleExpressions.jl @@ -259,7 +259,8 @@ Base.length(X::SymbolicEquation) = 2 ## ---- assymbolic(x::AbstractSymbolic) = x -assymbolic(x::Any) = SymbolicNumber(x) +assymbolic(x::Symbol) = Symbolic(x) +assymbolic(x::Number) = x #SymbolicNumber(x) issymbolic(x::AbstractSymbolic) = true issymbolic(::Any) = false @@ -291,6 +292,15 @@ end operation(x::SymbolicExpression) = x.op operation(::Any) = nothing +## ---- +""" + simplify(ex) + +Simplify expression using `Metatheory.jl` when that package is loaded +""" +simplify(x::AbstractSymbolic) = x # Metatheory.jl extension adds here +simplify(ex::SymbolicEquation) = simplify(ex.lhs) ~ simplify(ex.rhs) + ## ---- Base.show(io::IO, ::MIME"text/plain", x::AbstractSymbolic) = show(io, x) @@ -479,7 +489,7 @@ function _subs(::typeof(Base.broadcasted), args, y, p=nothing) end # only used for domain restrictions -Base.ifelse(p::AbstractSymbolic, a::Real, b::Real) = SymbolicExpression(ifelse, (p,a,b)) +Base.ifelse(p::AbstractSymbolic, a, b) = SymbolicExpression(ifelse, (p,a,b)) ## utils? Base.isequal(x::AbstractSymbolic, y::AbstractSymbolic) = hash(x) == hash(y) @@ -515,6 +525,37 @@ Base.convert(::Type{Expr}, x::SymbolicNumber) = x.x Base.convert(::Type{Expr}, x::SymbolicExpression) = Expr(:call, x.op, convert.(Expr, assymbolic.(x.arguments))...) + +# isless +Base.isless(x::Symbolic, y::Symbolic) = isless(x.x, y.x) +Base.isless(x::Symbolic, y::SymbolicParameter) = isless(x.x, y.x) +Base.isless(x::SymbolicParameter, y::Symbolic) = isless(x.x, y.x) +Base.isless(x::SymbolicParameter, y::SymbolicParameter) = isless(x.x, y.x) + +Base.isless(x::SymbolicNumber, y::AbstractSymbolic) = true +Base.isless(x::AbstractSymbolic, y::SymbolicNumber) = false +Base.isless(x::SymbolicNumber, y::SymbolicNumber) = isless(x.x, y.x) + +Base.isless(x::SymbolicExpression, y::Symbolic) = false +Base.isless(x::Symbolic, y::SymbolicExpression) = !isless(y,x) +Base.isless(x::SymbolicExpression, y::SymbolicParameter) = false +Base.isless(x::SymbolicParameter, y::SymbolicExpression) = !isless(y,x) +Base.isless(x::SymbolicExpression, y::SymbolicNumber) = isless(x.x, y.x) + +op_val(f) = Base.operator_precedence(Symbol(f)) +function Base.isless(x::SymbolicExpression, y::SymbolicExpression) + xo, yo = op_val(operation(x)), op_val(operation(y)) + isless(xo,yo) && return true + isless(yo, xo) && return false + xc, yc = x.arguments, y.arguments + isless(length(xc), length(yc)) && return true + isless(length(yc), length(xc)) && return false + for (cx, cy) ∈ zip(xc, yc) + isless(cx, cy) && return true + isless(cy, cx) && return false + end + false +end ## includes include("scalar-derivative.jl") From a664e96f1fa52e6de2adf455d1df6894f37edc61 Mon Sep 17 00:00:00 2001 From: jverzani Date: Mon, 5 Aug 2024 16:09:09 -0400 Subject: [PATCH 4/9] fix errors foudn by JET --- src/SimpleExpressions.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/SimpleExpressions.jl b/src/SimpleExpressions.jl index 6cc5f9e..70df9e6 100644 --- a/src/SimpleExpressions.jl +++ b/src/SimpleExpressions.jl @@ -282,7 +282,7 @@ function free_symbol(u::SymbolicExpression) a′ = free_symbol(a) isa(a′, Symbolic) && return a′ if isa(a′, SymbolicEquation) - u = free_symbol(a′) + a′′ = free_symbol(a′) isa(a′′, Symbolic) && return a′′ end end @@ -299,7 +299,7 @@ operation(::Any) = nothing Simplify expression using `Metatheory.jl` when that package is loaded """ simplify(x::AbstractSymbolic) = x # Metatheory.jl extension adds here -simplify(ex::SymbolicEquation) = simplify(ex.lhs) ~ simplify(ex.rhs) +simplify(ex::SymbolicEquation) = SymbolicEquation(simplify.(ex)...) ## ---- @@ -528,9 +528,9 @@ Base.convert(::Type{Expr}, x::SymbolicExpression) = # isless Base.isless(x::Symbolic, y::Symbolic) = isless(x.x, y.x) -Base.isless(x::Symbolic, y::SymbolicParameter) = isless(x.x, y.x) -Base.isless(x::SymbolicParameter, y::Symbolic) = isless(x.x, y.x) -Base.isless(x::SymbolicParameter, y::SymbolicParameter) = isless(x.x, y.x) +Base.isless(x::Symbolic, y::SymbolicParameter) = isless(x.x, y.p) +Base.isless(x::SymbolicParameter, y::Symbolic) = isless(x.p, y.x) +Base.isless(x::SymbolicParameter, y::SymbolicParameter) = isless(x.p, y.p) Base.isless(x::SymbolicNumber, y::AbstractSymbolic) = true Base.isless(x::AbstractSymbolic, y::SymbolicNumber) = false @@ -539,8 +539,9 @@ Base.isless(x::SymbolicNumber, y::SymbolicNumber) = isless(x.x, y.x) Base.isless(x::SymbolicExpression, y::Symbolic) = false Base.isless(x::Symbolic, y::SymbolicExpression) = !isless(y,x) Base.isless(x::SymbolicExpression, y::SymbolicParameter) = false -Base.isless(x::SymbolicParameter, y::SymbolicExpression) = !isless(y,x) -Base.isless(x::SymbolicExpression, y::SymbolicNumber) = isless(x.x, y.x) +Base.isless(x::SymbolicParameter, y::SymbolicExpression) = !isless(y, x) +Base.isless(x::SymbolicExpression, y::SymbolicNumber) = false +Base.isless(x::SymbolicNumber, y::SymbolicExpression) = !isless(y,x) op_val(f) = Base.operator_precedence(Symbol(f)) function Base.isless(x::SymbolicExpression, y::SymbolicExpression) From 875a646110b08b2eb57b445d97f912e0cc237df3 Mon Sep 17 00:00:00 2001 From: jverzani Date: Thu, 19 Sep 2024 09:36:33 -0400 Subject: [PATCH 5/9] relax commutative ones --- src/SimpleExpressions.jl | 43 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/src/SimpleExpressions.jl b/src/SimpleExpressions.jl index dfba4b1..49f529f 100644 --- a/src/SimpleExpressions.jl +++ b/src/SimpleExpressions.jl @@ -290,6 +290,16 @@ end operation(x::SymbolicExpression) = x.op operation(::Any) = nothing +arguments(x::SymbolicExpression) = x.arguments +arguments(x::Any) = (x,) + + +function is_operation(f) + ex -> begin + op = operation(ex) + !isnothing(op) && op == f + end +end ## ---- @@ -313,14 +323,23 @@ function Base.show(io::IO, x::SymbolicExpression) show(io, only(arguments)) print(io, ")") else - a, b = arguments + n = length(arguments) + for (i, a) ∈ enumerate(arguments) + isa(a, SymbolicExpression) && a.op ∈ infix_ops && print(io, "(") + show(io, a) + isa(a, SymbolicExpression) && a.op ∈ infix_ops && print(io, ")") + i != n && print(io, " ", broadcast, string(op), " ") + end + #= + a, bs..., c = arguments isa(a, SymbolicExpression) && a.op ∈ infix_ops && print(io, "(") - show(io, first(arguments)) + show(io, a) isa(a, SymbolicExpression) && a.op ∈ infix_ops && print(io, ")") print(io, " ", broadcast, string(op), " ") isa(b, SymbolicExpression) && b.op ∈ infix_ops && print(io, "(") show(io, b) isa(b, SymbolicExpression) && b.op ∈ infix_ops && print(io, ")") + =# end elseif op == ifelse p,a,b = arguments @@ -371,14 +390,25 @@ Base.:-(x::AbstractSymbolic) = SymbolicExpression(-, (x, )) function _commutative_op(op::typeof(+), x, y) iszero(x) && return y iszero(y) && return x - SymbolicExpression(+, _left_right(x,y)) + + (is_operation(+)(x) || is_operation(+)(y)) && + return SymbolicExpression(+, tuplejoin(arguments(x), arguments(y))) + + return SymbolicExpression(+, (x, y)) end function _commutative_op(op::typeof(*), x, y) isone(x) && return y isone(y) && return x (iszero(x) || iszero(y)) && return 0 - SymbolicExpression(*, _left_right(x,y)) + + (is_operation(*)(x) || is_operation(*)(y)) && + return SymbolicExpression(*, tuplejoin(arguments(x), arguments(y))) + + return SymbolicExpression(*, (x, y)) + + +# SymbolicExpression(*, _left_right(x,y)) end # commutative binary; slight canonicalization @@ -515,6 +545,11 @@ Base.convert(::Type{Expr}, x::SymbolicNumber) = x.x Base.convert(::Type{Expr}, x::SymbolicExpression) = Expr(:call, x.op, convert.(Expr, assymbolic.(x.arguments))...) +# tuplejoin (Discourse) +@inline tuplejoin(x) = x +@inline tuplejoin(x, y) = (x..., y...) +@inline tuplejoin(x, y, z...) = (x..., tuplejoin(y, z...)...) + ## includes include("scalar-derivative.jl") From 8757966984c326b69e4c2be72f65dff286ff9560 Mon Sep 17 00:00:00 2001 From: jverzani Date: Wed, 25 Sep 2024 16:23:57 -0400 Subject: [PATCH 6/9] borrow symbolic utils stuff for simplify --- Project.toml | 2 + ext/SimpleExpressionsMetatheoryExt.jl | 225 ++++++++++++++++++++--- ext/SimpleExpressionsTermInterfaceExt.jl | 65 ------- src/SimpleExpressions.jl | 11 +- 4 files changed, 207 insertions(+), 96 deletions(-) delete mode 100644 ext/SimpleExpressionsTermInterfaceExt.jl diff --git a/Project.toml b/Project.toml index 72635bb..2513f7b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["jverzani and contributors"] version = "1.0.22" [deps] +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [weakdeps] @@ -24,6 +25,7 @@ SimpleExpressionsSpecialFunctionsExt = "SpecialFunctions" [compat] AbstractTrees = "0.4" +Combinatorics = "1" Latexify = "0.16, 1" Metatheory = "3" RecipesBase = "1" diff --git a/ext/SimpleExpressionsMetatheoryExt.jl b/ext/SimpleExpressionsMetatheoryExt.jl index d043680..6a0b5e7 100644 --- a/ext/SimpleExpressionsMetatheoryExt.jl +++ b/ext/SimpleExpressionsMetatheoryExt.jl @@ -2,50 +2,215 @@ module SimpleExpressionsMetatheoryExt import SimpleExpressions import SimpleExpressions: SymbolicNumber +import SimpleExpressions: permutations, combinations using Metatheory using Metatheory.Library -function SimpleExpressions.simplify(ex::SimpleExpressions.SymbolicExpression) +# Modified from MIT licensed SymbolicUtils.jl +# https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rule.jl +struct ACRule{F,R} + sets::F + rule::R + arity::Int +end + +macro acrule(expr) + arity = length(expr.args[2].args[2:end]) + quote + ACRule(permutations, $(esc(:(@rule($(expr))))), $arity) + end +end + +macro ordered_acrule(expr) + arity = length(expr.args[2].args[2:end]) + quote + ACRule(combinations, $(esc(:(@rule($(expr))))), $arity) + end +end + +function (acr::ACRule)(term) + r = acr.rule + if !iscall(term) + r(term) + else + f = operation(term) + # # Assume that the matcher was formed by closing over a term + # if f != operation(r.lhs) # Maybe offer a fallback if m.term errors. + # return nothing + # end + + args = arguments(term) + + itr = acr.sets(eachindex(args), acr.arity) + + for inds in itr + result = r(f(args[inds]...)) #Term{T}(f, @views args[inds])) + if result !== nothing + # Assumption: inds are unique + length(args) == length(inds) && return result + return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], nothing) # metadata(term)) + end + end + end +end + +function SimpleExpressions.expand(ex::SimpleExpressions.SymbolicExpression) + + _expand_minus = @theory a b xs begin + -(a + b) => -a + -b + -1*(+(xs...)) => +(-xs...) + a - a => 0 + -a + a => 0 + end + - mult_t = @commutative_monoid (*) 1 - plus_t = @commutative_monoid (+) 0 - add_t = @theory a n m begin - a + a == 2a - n::Number * a + a == (n+1) * a - a + m::Number*a == (1 + m) * a - n::Number * a + m::Number * a == (n + m) * a + _expand_distributive = @theory x y z xs ys begin + z*(x + y) => z*x + z*y + (x + y) * z => z*x + z*y + # z * (+(xs...)) => sum(z*x for x in xs) + # +(xs...) * z --> + + z*(x - y) => z*x - z*y + (x - y) * z => z*x - z*y + + end + + _expand_binom = @theory x y n begin + (x + y)^1 => x + y + (x + y)^2 => x^2 + 2*x*y + y^2 + (x + y)^n::isinteger => sum(binomial(Int(n), k) * x^k * y^(n-k) for k in 0:Int(n)) end - minus_t = @theory a b begin - a - a --> 0 - a + (-b) == a - b + _expand_trig = @theory a b begin + sin(2a) => 2sin(a)*cos(a) + sin(a + b) => sin(a)*cos(b) + cos(a)*sin(b) + cos(2a) => cos(a)^2 - sin(a)^2 + cos(a + b) => cos(a)*cos(b) - sin(a)*sin(b) + sec(a) => 1 / cos(a) + csc(a) => 1 / sin(a) + tan(a) => sin(a)/cos(a) + cot(a) => cos(a)/sin(a) end - mulplus_t = @theory a b c begin - 0 * a --> 0 - a * 0 --> 0 - a * (b + c) == ((a * b) + (a * c)) - a + (b * a) == ((b + 1) * a) + + _expand_power = @theory x y a b begin + x^(a+b) => x^a*x^b + (x*y)^a => x^a * y^a + end + _expand_log = @theory x y n begin + log(x*y) => log(x) + log(y) + log(x^n) => n * log(x) end - pow_t = @theory x y z n m p q begin - (y^n) * y == y^(n + 1) - x^n * x^m == x^(n + m) - (x * y)^z == x^z * y^z - (x^p)^q == x^(p * q) - x^0 --> 1 - 0^x --> 0 - 1^x --> 1 - x^1 --> x - inv(x) == x^(-1) + _expand_misc = @theory a b begin + -a => (-1)*a + (1/a) * a => 1 + a * (1/a) => 1 + /(a,b) => *(a, b^(-1)) end - maths_theory = mult_t ∪ plus_t ∪ add_t ∪ minus_t ∪ mulplus_t ∪ pow_t - g = EGraph(ex) - saturate!(g, maths_theory) - extract!(g, astsize) + t = reduce(∪, ( + _expand_minus, + _expand_distributive, _expand_binom, _expand_trig, + _expand_power, _expand_log, + _expand_misc)) + + Metatheory.rewrite(ex, t) +end + + +function SimpleExpressions.simplify(ex::SimpleExpressions.SymbolicExpression) + + PLUS_DISTRIBUTE = [ + @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...)) + @acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...)) + ] + + CANONICALIZE_TIMES = [ + #@rule(~x::isnotflat(*) => flatten_term(*, ~x)) + #@rule(~x::needs_sorting(*) => sort_args(*, ~x)) + + # @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b) + # @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...)) + + @acrule((~y)^(~n) * ~y => (~y)^(~n+1)) + + @ordered_acrule((~z::isone * ~x) => ~x) + @ordered_acrule((~z::iszero * ~x) => ~z) + @rule(*(~x) => ~x) + ] + + MUL_DISTRIBUTE = @ordered_acrule((~x)^(~n) * (~x)^(~m) => (~x)^(~n + ~m)) + + CANONICALIZE_POW = [ + @rule(^(*(~~x), ~y::isinteger) => *(map(a->pow(a, ~y), ~~x)...)) + @rule((((~x)^(~p::isinteger))^(~q::isinteger)) => (~x)^((~p)*(~q))) + @rule(^(~x, ~z::iszero) => 1) + @rule(^(~x, ~z::isone) => ~x) + @rule(inv(~x) => 1/(~x)) + ] + + POW_RULES = [ + @rule(^(~x::isone, ~z) => 1) + ] + + ASSORTED_RULES = [ + @rule(identity(~x) => ~x) + @rule(-(~x) => -1*~x) + @rule(-(~x, ~y) => ~x + -1(~y)) + @rule(~x::isone \ ~y => ~y) + @rule(~x \ ~y => ~y / (~x)) + @rule(one(~x) => 1) #one(symtype(~x))) + @rule(zero(~x) => 0) #zero(symtype(~x))) + @rule(conj(~x::isreal) => ~x) + @rule(real(~x::isreal) => ~x) + @rule(imag(~x::isreal) => 0)#zero(symtype(~x))) + # @rule(ifelse(~x::is_literal_number, ~y, ~z) => ~x ? ~y : ~z) + @rule(ifelse(~x, ~y, ~y) => ~y) + ] + + TRIG_EXP_RULES = [ + # @acrule(~r*~x::has_trig_exp + ~r*~y => ~r*(~x + ~y)) + # @acrule(~r*~x::has_trig_exp + -1*~r*~y => ~r*(~x - ~y)) + @acrule(sin(~x)^2 + cos(~x)^2 => one(~x)) + @acrule(sin(~x)^2 + -1 => -1*cos(~x)^2) + @acrule(cos(~x)^2 + -1 => -1*sin(~x)^2) + + @acrule(cos(~x)^2 + -1*sin(~x)^2 => cos(2 * ~x)) + @acrule(sin(~x)^2 + -1*cos(~x)^2 => -cos(2 * ~x)) + @acrule(cos(~x) * sin(~x) => sin(2 * ~x)/2) + + @acrule(tan(~x)^2 + -1*sec(~x)^2 => one(~x)) + @acrule(-1*tan(~x)^2 + sec(~x)^2 => one(~x)) + @acrule(tan(~x)^2 + 1 => sec(~x)^2) + @acrule(sec(~x)^2 + -1 => tan(~x)^2) + + @acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x)) + @acrule(cot(~x)^2 + 1 => csc(~x)^2) + @acrule(csc(~x)^2 + -1 => cot(~x)^2) + + @acrule(cosh(~x)^2 + -1*sinh(~x)^2 => one(~x)) + @acrule(cosh(~x)^2 + -1 => sinh(~x)^2) + @acrule(sinh(~x)^2 + 1 => cosh(~x)^2) + + @acrule(cosh(~x)^2 + sinh(~x)^2 => cosh(2 * ~x)) + @acrule(cosh(~x) * sinh(~x) => sinh(2 * ~x)/2) + + @acrule(exp(~x) * exp(~y) => _iszero(~x + ~y) ? 1 : exp(~x + ~y)) + @rule(exp(~x)^(~y) => exp(~x * ~y)) + ] + + t = vcat(PLUS_DISTRIBUTE, + MUL_DISTRIBUTE, + CANONICALIZE_POW, + POW_RULES, + ASSORTED_RULES, + TRIG_EXP_RULES) + + rewrite(ex, t) + end end diff --git a/ext/SimpleExpressionsTermInterfaceExt.jl b/ext/SimpleExpressionsTermInterfaceExt.jl deleted file mode 100644 index 06a45fe..0000000 --- a/ext/SimpleExpressionsTermInterfaceExt.jl +++ /dev/null @@ -1,65 +0,0 @@ -module SimpleExpressionsTermInterfaceExt - -using SimpleExpressions - -import SimpleExpressions: AbstractSymbolic, - Symbolic, SymbolicParameter, SymbolicNumber, - SymbolicExpression, SymbolicEquation - -using TermInterface - -#In other symbolic expression languages, such as SymbolicUtils.jl, the head of a node can correspond to operation and children can correspond to arguments. - -TermInterface.head(ex::SymbolicExpression) = operation(ex) -TermInterface.children(ex::SymbolicExpression) = arguments(ex) - -TermInterface.operation(X::SymbolicExpression) = X.op -TermInterface.arguments(X::SymbolicExpression) = collect(X.arguments) - - -TermInterface.iscall(ex::SymbolicExpression) = true -TermInterface.iscall(ex::AbstractSymbolic) = false - - -TermInterface.isexpr(::Symbolic) = false -TermInterface.isexpr(::SymbolicParameter) = false -TermInterface.isexpr(::SymbolicNumber) = false -TermInterface.isexpr(::AbstractSymbolic) = true - - - -function TermInterface.maketerm(T::Type{<:AbstractSymbolic}, head, children, metadata) - if isa(head, Symbol) - head == :. && return first(children) - @show head, children, metadata - return 42 - end - head(SimpleExpressions.assymbolic.(children)...) -end - - -TermInterface.arity(::AbstractSymbolic) = 0 -TermInterface.arity(ex::SymbolicExpression) = length(ex.arguments) - -TermInterface.metadata(::AbstractSymbolic) = nothing - - -# convert from Expression to SimpleExpression -# all variables become `𝑥` except `p` becomes `𝑝`, a parameter -function SimpleExpressions.assymbolic(x::Expr) - body = _assymbolic(x) - eval(body) -end - -function _assymbolic(x) - if !TermInterface.istree(x) - isa(x, Symbol) && return x == :p ? :(SymbolicParameter(:𝑝)) : :(Symbolic(:𝑥)) - return x - end - - op = TermInterface.operation(x) - arguments = TermInterface.arguments(x) - Expr(:call, op, _assymbolic.(arguments)...) -end - -end diff --git a/src/SimpleExpressions.jl b/src/SimpleExpressions.jl index c45131c..a86b5c4 100644 --- a/src/SimpleExpressions.jl +++ b/src/SimpleExpressions.jl @@ -7,6 +7,7 @@ $(joinpath(@__DIR__, "..", "README.md") |> """ module SimpleExpressions +using Combinatorics using TermInterface export @symbolic @@ -268,7 +269,7 @@ Base.length(X::SymbolicEquation) = 2 TermInterface.operation(x::AbstractSymbolic) = nothing TermInterface.operation(x::SymbolicExpression) = x.op TermInterface.arguments(x::AbstractSymbolic) = nothing -TermInterface.arguments(x::SymbolicExpression) = x.arguments +TermInterface.arguments(x::SymbolicExpression) = collect(x.arguments) TermInterface.head(ex::SymbolicExpression) = operation(ex) TermInterface.children(ex::SymbolicExpression) = arguments(ex) @@ -341,6 +342,14 @@ Simplify expression using `Metatheory.jl` when that package is loaded simplify(x::AbstractSymbolic) = x # Metatheory.jl extension adds here simplify(ex::SymbolicEquation) = SymbolicEquation(simplify.(ex)...) +""" + expand(ex) + +Expand expression using `Metatheory.jl` when that package is loaded +""" +expand(x::AbstractSymbolic) = x # Metatheory.jl extension adds here +expand(ex::SymbolicEquation) = SymbolicEquation(expand.(ex)...) + ## ---- Base.show(io::IO, ::MIME"text/plain", x::AbstractSymbolic) = show(io, x) From ffc7d9fbc6f1a1ff382c3ccfc7f7c19287bba5fa Mon Sep 17 00:00:00 2001 From: jverzani Date: Wed, 25 Sep 2024 16:25:35 -0400 Subject: [PATCH 7/9] adjust bound --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2513f7b..c717079 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ SimpleExpressionsSpecialFunctionsExt = "SpecialFunctions" AbstractTrees = "0.4" Combinatorics = "1" Latexify = "0.16, 1" -Metatheory = "3" +Metatheory = "2,3" RecipesBase = "1" Roots = "2" SpecialFunctions = "1,2" From 264891e5df8c03b3b87f32d1c41485ce0a204754 Mon Sep 17 00:00:00 2001 From: jverzani Date: Wed, 25 Sep 2024 16:27:05 -0400 Subject: [PATCH 8/9] adjust bound --- .#Project.toml | 1 + 1 file changed, 1 insertion(+) create mode 120000 .#Project.toml diff --git a/.#Project.toml b/.#Project.toml new file mode 120000 index 0000000..c6422bb --- /dev/null +++ b/.#Project.toml @@ -0,0 +1 @@ +jverzani@john-verzanis-macbook-pro.local.56759 \ No newline at end of file From 68b9685631713e288b6bb2047a0df2dd30613fc7 Mon Sep 17 00:00:00 2001 From: jverzani Date: Wed, 25 Sep 2024 16:30:08 -0400 Subject: [PATCH 9/9] compat --- .#Project.toml | 1 - Project.toml | 2 +- test/runtests.jl | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) delete mode 120000 .#Project.toml diff --git a/.#Project.toml b/.#Project.toml deleted file mode 120000 index c6422bb..0000000 --- a/.#Project.toml +++ /dev/null @@ -1 +0,0 @@ -jverzani@john-verzanis-macbook-pro.local.56759 \ No newline at end of file diff --git a/Project.toml b/Project.toml index c717079..2513f7b 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ SimpleExpressionsSpecialFunctionsExt = "SpecialFunctions" AbstractTrees = "0.4" Combinatorics = "1" Latexify = "0.16, 1" -Metatheory = "2,3" +Metatheory = "3" RecipesBase = "1" Roots = "2" SpecialFunctions = "1,2" diff --git a/test/runtests.jl b/test/runtests.jl index fba83a1..0c4c110 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,4 +4,4 @@ using Test import SimpleExpressions: @symbolic_expression include("basic_tests.jl") -include("extension_tests.jl") +#include("extension_tests.jl")