diff --git a/src/IRMatch.h b/src/IRMatch.h index 4ec8b2694e3f..283d690cecec 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -479,7 +479,12 @@ struct Wild { template std::ostream &operator<<(std::ostream &s, const Wild &op) { - s << "_" << i; + constexpr const char *names[] = {"x", "y", "z", "w", "u", "v"}; + if constexpr (i < std::size(names)) { + s << names[i]; + } else { + s << "_" << i; + } return s; } @@ -704,7 +709,7 @@ struct BinOp { } HALIDE_ALWAYS_INLINE - Expr make(MatcherState &state, halide_type_t type_hint) const noexcept { + Expr make(MatcherState &state, halide_type_t type_hint) const { Expr ea, eb; if (std::is_same_v) { eb = b.make(state, type_hint); @@ -1976,7 +1981,20 @@ struct VectorReduceOp { template inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp &op) { - s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")"; + if constexpr (reduce_op == VectorReduce::Add) { + s << "h_add("; + } else if constexpr (reduce_op == VectorReduce::Min) { + s << "h_min("; + } else if constexpr (reduce_op == VectorReduce::Max) { + s << "h_max("; + } else if constexpr (reduce_op == VectorReduce::And) { + s << "h_and("; + } else if constexpr (reduce_op == VectorReduce::Or) { + s << "h_or("; + } else { + s << "vector_reduce(" << reduce_op << ", "; + } + s << op.a << ", " << op.lanes << ")"; return s; } diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 8fb7554ec84b..b66fcb3f33ae 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -130,7 +130,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) || rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes)) || - rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes))) { + rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes)) || + rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, 1) * cast(op->type.element_of(), (c0 / lanes)), lanes), c0 % lanes == 0) || + false) { return mutate(rewrite.result, info); } break; @@ -142,8 +145,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_min(max(x, broadcast(y, arg_lanes)), lanes), max(h_min(x, lanes), broadcast(y, lanes))) || rewrite(h_min(max(broadcast(x, arg_lanes), y), lanes), max(h_min(y, lanes), broadcast(x, lanes))) || rewrite(h_min(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_min(broadcast(x, c0), lanes), h_min(x, lanes), factor % c0 == 0) || - rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) || + rewrite(h_min(broadcast(x, c0), 1), h_min(x, 1)) || + rewrite(h_min(broadcast(x, c0), lanes), broadcast(h_min(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_min(ramp(x, y, arg_lanes), 1), x + min(y * (arg_lanes - 1), 0)) || + rewrite(h_min(ramp(x, y, arg_lanes), lanes), ramp(x + min(y * (factor - 1), 0), y * factor, lanes)) || false) { return mutate(rewrite.result, info); } @@ -156,8 +161,10 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_max(max(x, broadcast(y, arg_lanes)), lanes), max(h_max(x, lanes), broadcast(y, lanes))) || rewrite(h_max(max(broadcast(x, arg_lanes), y), lanes), max(h_max(y, lanes), broadcast(x, lanes))) || rewrite(h_max(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) || - rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) || + rewrite(h_max(broadcast(x, c0), 1), h_max(x, 1)) || + rewrite(h_max(broadcast(x, c0), lanes), broadcast(h_max(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_max(ramp(x, y, arg_lanes), 1), x + max(y * (arg_lanes - 1), 0)) || + rewrite(h_max(ramp(x, y, arg_lanes), lanes), ramp(x + max(y * (factor - 1), 0), y * factor, lanes)) || false) { return mutate(rewrite.result, info); } @@ -170,14 +177,15 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_and(x && broadcast(y, arg_lanes), lanes), h_and(x, lanes) && broadcast(y, lanes)) || rewrite(h_and(broadcast(x, arg_lanes) && y, lanes), h_and(y, lanes) && broadcast(x, lanes)) || rewrite(h_and(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_and(broadcast(x, c0), lanes), h_and(x, lanes), factor % c0 == 0) || - rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), + rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, 1), lanes), c0 >= lanes) || + rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), 1), x + max(y * (arg_lanes - 1), 0) < z) || - rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), + rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), 1), x + max(y * (arg_lanes - 1), 0) <= z) || - rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), 1), x < y + min(z * (arg_lanes - 1), 0)) || - rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), 1), x <= y + min(z * (arg_lanes - 1), 0)) || false) { return mutate(rewrite.result, info); @@ -191,15 +199,16 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_or(x && broadcast(y, arg_lanes), lanes), h_or(x, lanes) && broadcast(y, lanes)) || rewrite(h_or(broadcast(x, arg_lanes) && y, lanes), h_or(y, lanes) && broadcast(x, lanes)) || rewrite(h_or(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || - rewrite(h_or(broadcast(x, c0), lanes), h_or(x, lanes), factor % c0 == 0) || + rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, lanes / c0), c0), lanes % c0 == 0) || + rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, 1), lanes), c0 >= lanes) || // type of arg_lanes is somewhat indeterminate - rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), + rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), 1), x + min(y * (arg_lanes - 1), 0) < z) || - rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), + rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), 1), x + min(y * (arg_lanes - 1), 0) <= z) || - rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), 1), x < y + max(z * (arg_lanes - 1), 0)) || - rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), 1), x <= y + max(z * (arg_lanes - 1), 0)) || false) { return mutate(rewrite.result, info); diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index de10bde5a1b9..ff934e1b82ba 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -805,11 +805,64 @@ void check_vectors() { check(VectorReduce::make(VectorReduce::And, Broadcast::make(bool_vector, 4), 1), VectorReduce::make(VectorReduce::And, bool_vector, 1)); check(VectorReduce::make(VectorReduce::Or, Broadcast::make(bool_vector, 4), 2), - VectorReduce::make(VectorReduce::Or, bool_vector, 2)); + Broadcast::make(VectorReduce::make(VectorReduce::Or, bool_vector, 1), 2)); check(VectorReduce::make(VectorReduce::Min, Broadcast::make(int_vector, 4), 4), - int_vector); + Broadcast::make(VectorReduce::make(VectorReduce::Min, int_vector, 1), 4)); check(VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8), - VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8)); + Broadcast::make(VectorReduce::make(VectorReduce::Max, int_vector, 2), 4)); + + { + Expr x = Variable::make(Int(32), "x"); + Expr y = Variable::make(Int(32), "y"); + + // == Symbolic Strides == + + // 1. Min: Scalar Reduction (arg_lanes=4, lanes=1 -> factor=4) + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, y, 4), 1), + min(y, 0) * 3 + x); + + // 2. Min: Vector Reduction (arg_lanes=6, lanes=2 -> factor=3) + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, y, 6), 2), + Ramp::make(min(y, 0) * 2 + x, y * 3, 2)); + + // 3. Max: Scalar Reduction (arg_lanes=4, lanes=1 -> factor=4) + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, y, 4), 1), + max(y, 0) * 3 + x); + + // 4. Max: Vector Reduction (arg_lanes=6, lanes=2 -> factor=3) + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, y, 6), 2), + Ramp::make(max(y, 0) * 2 + x, y * 3, 2)); + + // == Constant Strides (Positive & Negative) == + + // 5. Min: Positive Stride (arg_lanes=8, lanes=2 -> factor=4, stride=2) + // Block 1: min(x, x+2, x+4, x+6) -> x + // Expected Base: x + min(2 * 3, 0) -> x + 0 -> x + // Expected Stride: 2 * 4 = 8 + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, 2, 8), 2), + Ramp::make(x, 8, 2)); + + // 6. Max: Positive Stride (arg_lanes=8, lanes=2 -> factor=4, stride=2) + // Block 1: max(x, x+2, x+4, x+6) -> x+6 + // Expected Base: x + max(2 * 3, 0) -> x + 6 + // Expected Stride: 2 * 4 = 8 + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, 2, 8), 2), + Ramp::make(x + 6, 8, 2)); + + // 7. Min: Negative Stride (arg_lanes=8, lanes=2 -> factor=4, stride=-2) + // Block 1: min(x, x-2, x-4, x-6) -> x-6 + // Expected Base: x + min(-2 * 3, 0) -> x - 6 + // Expected Stride: -2 * 4 = -8 + check(VectorReduce::make(VectorReduce::Min, Ramp::make(x, -2, 8), 2), + Ramp::make(x + -6, -8, 2)); + + // 8. Max: Negative Stride (arg_lanes=8, lanes=2 -> factor=4, stride=-2) + // Block 1: max(x, x-2, x-4, x-6) -> x + // Expected Base: x + max(-2 * 3, 0) -> x + 0 -> x + // Expected Stride: -2 * 4 = -8 + check(VectorReduce::make(VectorReduce::Max, Ramp::make(x, -2, 8), 2), + Ramp::make(x, -8, 2)); + } { // h_add(broadcast(x, 8), 4) should simplify to broadcast(x * 2, 4) diff --git a/test/fuzz/random_expr_generator.h b/test/fuzz/random_expr_generator.h index abb10351260f..7cc556efac3b 100644 --- a/test/fuzz/random_expr_generator.h +++ b/test/fuzz/random_expr_generator.h @@ -363,12 +363,10 @@ class RandomExpressionGenerator { int factor = fuzz.ConsumeIntegralInRange(2, 4); int input_lanes = t.lanes() * factor; if (input_lanes <= 32) { - VectorReduce::Operator ops[] = { - VectorReduce::Add, - VectorReduce::Min, - VectorReduce::Max, - }; - auto op = fuzz.PickValueInArray(ops); + auto op = + t.is_bool() ? + fuzz.PickValueInArray({VectorReduce::And, VectorReduce::Or}) : + fuzz.PickValueInArray({VectorReduce::Add, VectorReduce::Min, VectorReduce::Max}); Expr val = random_expr(t.with_lanes(input_lanes), depth); internal_assert(val.type().lanes() == input_lanes) << val; return VectorReduce::make(op, val, t.lanes()); diff --git a/test/fuzz/simplify.cpp b/test/fuzz/simplify.cpp index 2ddc82e6913d..49f0530f978e 100644 --- a/test/fuzz/simplify.cpp +++ b/test/fuzz/simplify.cpp @@ -12,19 +12,55 @@ using std::string; using namespace Halide; using namespace Halide::Internal; +struct SimpilfyResult : public std::variant { + using std::variant::variant; + bool ok() const { + return index() == 0; + } + bool failed() const { + return index() == 1; + } + operator Expr() const { + return std::get(*this); + } +}; + +SimpilfyResult safe_simplify(const Expr &e) { + try { + return simplify(e); + } catch (InternalError &err) { + std::cerr << "Simplifier failed to simplify expression:\n" + << e << "\n"; + std::cerr << err.what() << "\n"; + return err; + } +} + bool test_simplification(Expr a, Expr b, const map &vars) { if (equal(a, b) && !a.same_as(b)) { std::cerr << "Simplifier created new IR node but made no changes:\n" << a << "\n"; return false; } - if (Expr sb = simplify(b); !equal(b, sb)) { + SimpilfyResult sb = safe_simplify(b); + if (sb.failed() || !equal(b, (Expr)sb)) { // Test all sub-expressions in pre-order traversal to minimize bool found_failure = false; mutate_with(a, [&](auto *self, const Expr &e) { self->mutate_base(e); - Expr s = simplify(e); - Expr ss = simplify(s); + Expr s, ss; + if (SimpilfyResult res = safe_simplify(e); res.ok()) { + s = res; + } else { + found_failure = true; + return e; + } + if (SimpilfyResult res = safe_simplify(s); res.ok()) { + ss = res; + } else { + found_failure = true; + return e; + } if (!found_failure && !equal(s, ss)) { std::cerr << "Idempotency failure\n " << e << "\n -> " @@ -34,10 +70,10 @@ bool test_simplification(Expr a, Expr b, const map &vars) { // added to the simplifier to debug the failure. std::cerr << "---------------------------------\n" << "Begin simplification of original:\n" - << simplify(e) << "\n"; + << s << "\n"; std::cerr << "---------------------------------\n" << "Begin resimplification of result:\n" - << simplify(s) << "\n" + << ss << "\n" << "---------------------------------\n"; found_failure = true; @@ -47,8 +83,20 @@ bool test_simplification(Expr a, Expr b, const map &vars) { return false; } - Expr a_v = simplify(substitute(vars, a)); - Expr b_v = simplify(substitute(vars, b)); + Expr a_v = substitute(vars, a); + if (SimpilfyResult res = safe_simplify(a_v); res.ok()) { + a_v = res; + } else { + return false; + } + + Expr b_v = substitute(vars, b); + if (SimpilfyResult res = safe_simplify(b_v); res.ok()) { + b_v = res; + } else { + return false; + } + // If the simplifier didn't produce constants, there must be // undefined behavior in this expression. Ignore it. if (!Internal::is_const(a_v) || !Internal::is_const(b_v)) { @@ -72,7 +120,12 @@ bool test_simplification(Expr a, Expr b, const map &vars) { } bool test_expression(RandomExpressionGenerator ®, Expr test, int samples) { - Expr simplified = simplify(test); + Expr simplified; + if (SimpilfyResult res = safe_simplify(test); res.ok()) { + simplified = res; + } else { + return false; + } map vars; for (const auto &fuzz_var : reg.fuzz_vars) { @@ -97,16 +150,20 @@ bool test_expression(RandomExpressionGenerator ®, Expr test, int samples) { return true; } -Expr simplify_at_depth(int limit, const Expr &in) { - return mutate_with(in, [&](auto *self, const Expr &e) { - if (limit == 0) { - return simplify(e); - } - limit--; - Expr new_e = self->mutate_base(e); - limit++; - return new_e; - }); +SimpilfyResult simplify_at_depth(int limit, const Expr &in) { + try { + return mutate_with(in, [&](auto *self, const Expr &e) { + if (limit == 0) { + return simplify(e); + } + limit--; + Expr new_e = self->mutate_base(e); + limit++; + return new_e; + }); + } catch (InternalError &err) { + return err; + } } } // namespace @@ -125,7 +182,7 @@ FUZZ_TEST(simplify, FuzzingContext &fuzz) { // FIXME: These need to be disabled (otherwise crashes and/or failures): // reg.gen_ramp_of_vector = false; // reg.gen_broadcast_of_vector = false; - reg.gen_vector_reduce = false; + // reg.gen_vector_reduce = false; reg.gen_reinterpret = false; reg.gen_shuffles = false; @@ -140,10 +197,16 @@ FUZZ_TEST(simplify, FuzzingContext &fuzz) { self->mutate_base(e); if (e.type().bits() && !found_failure) { for (int i = 1; i < 4 && !found_failure; i++) { - Expr limited = simplify_at_depth(i, e); - found_failure = !test_expression(reg, limited, samples_during_minimization); - if (found_failure) { - return limited; + SimpilfyResult limited_res = simplify_at_depth(i, e); + if (limited_res.failed()) { + found_failure = true; + return e; + } else { + Expr limited = limited_res; + found_failure = !test_expression(reg, limited, samples_during_minimization); + if (found_failure) { + return limited; + } } } if (!found_failure) {