diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index 6df39e44..d0a24ead 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -618,11 +618,18 @@ class Evaluator( checkStackDepth(e.pos, e) try { val lhs = visitExpr(e.value) + // Auto-TCO'd calls should normally be intercepted by visitExprWithTailCallSupport, + // but we handle them defensively here to preserve lazy semantics if this path is ever reached. implicit val tailstrictMode: TailstrictMode = - if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + if (!e.strict && e.tailstrict) TailstrictModeAutoTCO + else if (e.tailstrict) TailstrictModeEnabled + else TailstrictModeDisabled if (e.tailstrict) { - TailCall.resolve(lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)) + val args: Array[Eval] = + if (!e.strict) e.args.map(visitAsLazy(_)) + else e.args.map(visitExpr(_)).asInstanceOf[Array[Eval]] + TailCall.resolve(lhs.cast[Val.Func].apply(args, e.namedNames, e.pos)) } else { lhs.cast[Val.Func].apply(e.args.map(visitAsLazy(_)), e.namedNames, e.pos) } @@ -635,7 +642,9 @@ class Evaluator( try { val lhs = visitExpr(e.value) implicit val tailstrictMode: TailstrictMode = - if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + if (!e.strict && e.tailstrict) TailstrictModeAutoTCO + else if (e.tailstrict) TailstrictModeEnabled + else TailstrictModeDisabled if (e.tailstrict) { TailCall.resolve(lhs.cast[Val.Func].apply0(e.pos)) } else { @@ -650,9 +659,12 @@ class Evaluator( try { val lhs = visitExpr(e.value) implicit val tailstrictMode: TailstrictMode = - if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + if (!e.strict && e.tailstrict) TailstrictModeAutoTCO + else if (e.tailstrict) TailstrictModeEnabled + else TailstrictModeDisabled if (e.tailstrict) { - TailCall.resolve(lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)) + val arg: Eval = if (!e.strict) visitAsLazy(e.a1) else visitExpr(e.a1) + TailCall.resolve(lhs.cast[Val.Func].apply1(arg, e.pos)) } else { val l1 = visitAsLazy(e.a1) lhs.cast[Val.Func].apply1(l1, e.pos) @@ -666,10 +678,18 @@ class Evaluator( try { val lhs = visitExpr(e.value) implicit val tailstrictMode: TailstrictMode = - if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + if (!e.strict && e.tailstrict) TailstrictModeAutoTCO + else if (e.tailstrict) TailstrictModeEnabled + else TailstrictModeDisabled if (e.tailstrict) { - TailCall.resolve(lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)) + if (!e.strict) { + TailCall.resolve( + lhs.cast[Val.Func].apply2(visitAsLazy(e.a1), visitAsLazy(e.a2), e.pos) + ) + } else { + TailCall.resolve(lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)) + } } else { val l1 = visitAsLazy(e.a1) val l2 = visitAsLazy(e.a2) @@ -684,12 +704,22 @@ class Evaluator( try { val lhs = visitExpr(e.value) implicit val tailstrictMode: TailstrictMode = - if (e.tailstrict) TailstrictModeEnabled else TailstrictModeDisabled + if (!e.strict && e.tailstrict) TailstrictModeAutoTCO + else if (e.tailstrict) TailstrictModeEnabled + else TailstrictModeDisabled if (e.tailstrict) { - TailCall.resolve( - lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) - ) + if (!e.strict) { + TailCall.resolve( + lhs + .cast[Val.Func] + .apply3(visitAsLazy(e.a1), visitAsLazy(e.a2), visitAsLazy(e.a3), e.pos) + ) + } else { + TailCall.resolve( + lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) + ) + } } else { val l1 = visitAsLazy(e.a1) val l2 = visitAsLazy(e.a2) @@ -1158,13 +1188,36 @@ class Evaluator( } } + // And/Or rhs tail-position helpers — extracted to preserve @tailrec on visitExprWithTailCallSupport. + // TailCall sentinels pass through without a boolean type check: this is a deliberate semantic + // relaxation matching google/jsonnet behavior (where `&&` is simply `if a then b else false`). + // Direct non-boolean rhs values (e.g. `true && "hello"`) are still caught. + private def visitAndRhsTailPos(rhs: Expr, andPos: Position)(implicit scope: ValScope): Val = { + visitExprWithTailCallSupport(rhs) match { + case b: Val.Bool => b + case tc: TailCall => tc + case unknown => + Error.fail(s"binary operator && does not operate on ${unknown.prettyName}s.", andPos) + } + } + + private def visitOrRhsTailPos(rhs: Expr, orPos: Position)(implicit scope: ValScope): Val = { + visitExprWithTailCallSupport(rhs) match { + case b: Val.Bool => b + case tc: TailCall => tc + case unknown => + Error.fail(s"binary operator || does not operate on ${unknown.prettyName}s.", orPos) + } + } + /** * Evaluate an expression with tail-call support. When a `tailstrict` call is encountered at a * potential tail position, returns a [[TailCall]] sentinel instead of recursing, enabling * `TailCall.resolve` in `visitApply*` to iterate rather than grow the JVM stack. * * Potential tail positions are propagated through: IfElse (both branches), LocalExpr (returned), - * and AssertExpr (returned). All other expression types delegate to normal `visitExpr`. + * AssertExpr (returned), And (rhs), Or (rhs), and Expr.Error (value). All other expression types + * delegate to normal `visitExpr`. */ @tailrec private def visitExprWithTailCallSupport(e: Expr)(implicit scope: ValScope): Val = e match { @@ -1208,6 +1261,26 @@ class Evaluator( } } visitExprWithTailCallSupport(e.returned) + case e: And => + // rhs of && is in tail position: when lhs is true, rhs is returned directly. + // Type check via helper to preserve @tailrec on this method. + visitExpr(e.lhs) match { + case _: Val.True => visitAndRhsTailPos(e.rhs, e.pos) + case _: Val.False => Val.staticFalse + case unknown => + Error.fail(s"binary operator && does not operate on ${unknown.prettyName}s.", e.pos) + } + case e: Or => + // rhs of || is in tail position: when lhs is false, rhs is returned directly. + // Type check via helper to preserve @tailrec on this method. + visitExpr(e.lhs) match { + case _: Val.True => Val.staticTrue + case _: Val.False => visitOrRhsTailPos(e.rhs, e.pos) + case unknown => + Error.fail(s"binary operator || does not operate on ${unknown.prettyName}s.", e.pos) + } + case e: Expr.Error => + Error.fail(materializeError(visitExpr(e.value)), e.pos) // Tail-position tailstrict calls: match TailstrictableExpr to unify the tailstrict guard, // then dispatch by concrete type. // @@ -1222,33 +1295,41 @@ class Evaluator( e match { case e: Apply => try { + val isStrict = e.isStrict val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, e.args.map(visitExpr(_)).asInstanceOf[Array[Eval]], e.namedNames, e) + val args: Array[Eval] = + if (!isStrict) e.args.map(visitAsLazy(_)) + else e.args.map(visitExpr(_)).asInstanceOf[Array[Eval]] + new TailCall(func, args, e.namedNames, e, strict = isStrict) } catch Error.withStackFrame(e) case e: Apply0 => try { val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, Evaluator.emptyLazyArray, null, e) + new TailCall(func, Evaluator.emptyLazyArray, null, e, strict = e.isStrict) } catch Error.withStackFrame(e) case e: Apply1 => try { + val isStrict = e.isStrict val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, Array[Eval](visitExpr(e.a1)), null, e) + val arg: Eval = if (!isStrict) visitAsLazy(e.a1) else visitExpr(e.a1) + new TailCall(func, Array[Eval](arg), null, e, strict = isStrict) } catch Error.withStackFrame(e) case e: Apply2 => try { + val isStrict = e.isStrict val func = visitExpr(e.value).cast[Val.Func] - new TailCall(func, Array[Eval](visitExpr(e.a1), visitExpr(e.a2)), null, e) + val a1: Eval = if (!isStrict) visitAsLazy(e.a1) else visitExpr(e.a1) + val a2: Eval = if (!isStrict) visitAsLazy(e.a2) else visitExpr(e.a2) + new TailCall(func, Array[Eval](a1, a2), null, e, strict = isStrict) } catch Error.withStackFrame(e) case e: Apply3 => try { + val isStrict = e.isStrict val func = visitExpr(e.value).cast[Val.Func] - new TailCall( - func, - Array[Eval](visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3)), - null, - e - ) + val a1: Eval = if (!isStrict) visitAsLazy(e.a1) else visitExpr(e.a1) + val a2: Eval = if (!isStrict) visitAsLazy(e.a2) else visitExpr(e.a2) + val a3: Eval = if (!isStrict) visitAsLazy(e.a3) else visitExpr(e.a3) + new TailCall(func, Array[Eval](a1, a2, a3), null, e, strict = isStrict) } catch Error.withStackFrame(e) case _ => visitExpr(e) } diff --git a/sjsonnet/src/sjsonnet/Expr.scala b/sjsonnet/src/sjsonnet/Expr.scala index 16d5da0d..2f035c65 100644 --- a/sjsonnet/src/sjsonnet/Expr.scala +++ b/sjsonnet/src/sjsonnet/Expr.scala @@ -52,6 +52,13 @@ trait Expr { */ trait TailstrictableExpr extends Expr { def tailstrict: Boolean + + /** + * True when this call was marked as strict (eager argument evaluation) by an explicit + * `tailstrict` annotation. False when auto-TCO'd (lazy argument evaluation to preserve Jsonnet's + * standard lazy semantics). + */ + def isStrict: Boolean = false } object Expr { @@ -231,25 +238,45 @@ object Expr { value: Expr, args: Array[Expr], namedNames: Array[String], - tailstrict: Boolean) + tailstrict: Boolean, + strict: Boolean = true) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply override def exprErrorString: String = Expr.callTargetName(value) + override def isStrict: Boolean = strict } - final case class Apply0(var pos: Position, value: Expr, tailstrict: Boolean) + final case class Apply0( + var pos: Position, + value: Expr, + tailstrict: Boolean, + strict: Boolean = true) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply0 override def exprErrorString: String = Expr.callTargetName(value) + override def isStrict: Boolean = strict } - final case class Apply1(var pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) + final case class Apply1( + var pos: Position, + value: Expr, + a1: Expr, + tailstrict: Boolean, + strict: Boolean = true) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply1 override def exprErrorString: String = Expr.callTargetName(value) + override def isStrict: Boolean = strict } - final case class Apply2(var pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) + final case class Apply2( + var pos: Position, + value: Expr, + a1: Expr, + a2: Expr, + tailstrict: Boolean, + strict: Boolean = true) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply2 override def exprErrorString: String = Expr.callTargetName(value) + override def isStrict: Boolean = strict } final case class Apply3( var pos: Position, @@ -257,10 +284,12 @@ object Expr { a1: Expr, a2: Expr, a3: Expr, - tailstrict: Boolean) + tailstrict: Boolean, + strict: Boolean = true) extends TailstrictableExpr { final override private[sjsonnet] def tag = ExprTags.Apply3 override def exprErrorString: String = Expr.callTargetName(value) + override def isStrict: Boolean = strict } final case class ApplyBuiltin( var pos: Position, diff --git a/sjsonnet/src/sjsonnet/ExprTransform.scala b/sjsonnet/src/sjsonnet/ExprTransform.scala index 38d4a051..af56fba2 100644 --- a/sjsonnet/src/sjsonnet/ExprTransform.scala +++ b/sjsonnet/src/sjsonnet/ExprTransform.scala @@ -14,37 +14,37 @@ abstract class ExprTransform { if (x2 eq x) expr else Select(pos, x2, name) - case Apply(pos, x, y, namedNames, tailstrict) => + case Apply(pos, x, y, namedNames, tailstrict, strict) => val x2 = transform(x) val y2 = transformArr(y) if ((x2 eq x) && (y2 eq y)) expr - else Apply(pos, x2, y2, namedNames, tailstrict) + else Apply(pos, x2, y2, namedNames, tailstrict, strict) - case Apply0(pos, x, tailstrict) => + case Apply0(pos, x, tailstrict, strict) => val x2 = transform(x) if (x2 eq x) expr - else Apply0(pos, x2, tailstrict) + else Apply0(pos, x2, tailstrict, strict) - case Apply1(pos, x, y, tailstrict) => + case Apply1(pos, x, y, tailstrict, strict) => val x2 = transform(x) val y2 = transform(y) if ((x2 eq x) && (y2 eq y)) expr - else Apply1(pos, x2, y2, tailstrict) + else Apply1(pos, x2, y2, tailstrict, strict) - case Apply2(pos, x, y, z, tailstrict) => + case Apply2(pos, x, y, z, tailstrict, strict) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) if ((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else Apply2(pos, x2, y2, z2, tailstrict) + else Apply2(pos, x2, y2, z2, tailstrict, strict) - case Apply3(pos, x, y, z, a, tailstrict) => + case Apply3(pos, x, y, z, a, tailstrict, strict) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) val a2 = transform(a) if ((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr - else Apply3(pos, x2, y2, z2, a2, tailstrict) + else Apply3(pos, x2, y2, z2, a2, tailstrict, strict) case ApplyBuiltin(pos, func, x, tailstrict) => val x2 = transformArr(x) diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index 16946056..4f338438 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -279,10 +279,11 @@ class StaticOptimizer( if (a.namedNames != null) a else a.args.length match { - case 0 => Apply0(a.pos, a.value, a.tailstrict) - case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict) - case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict) - case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict) + case 0 => Apply0(a.pos, a.value, a.tailstrict, a.strict) + case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict, a.strict) + case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict, a.strict) + case 3 => + Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict, a.strict) case _ => a } } @@ -542,4 +543,217 @@ class StaticOptimizer( } Val.bool(pos, if (negate) !result else result) } + + /** + * Auto-TCO: override transformBind to detect self-recursive tail calls in function bodies and + * mark them as `tailstrict = true`. This enables the evaluator's TailCall trampoline for those + * calls, preventing JVM stack overflow on deep recursion — without requiring the user to annotate + * every recursive call site with the `tailstrict` keyword. + * + * Safety: we only mark a call as tailstrict when the call provides exactly as many positional + * arguments as the function declares parameters (no named args). This ensures the evaluator takes + * the "simple" apply path where only the already-evaluated passed arguments are forced, avoiding + * the eager evaluation of default argument expressions that `tailstrict` would otherwise trigger. + * + * Handles both binding forms: + * - `local sum(n, acc) = ...` → Bind(args = Params(...), rhs = body) + * - `local sum = function(n, acc) ...` → Bind(args = null, rhs = Function(params, body)) + * + * @see + * [[TailCall]] for the sentinel value used in the TCO protocol + * @see + * https://github.com/databricks/sjsonnet/issues/623 + */ + override def transformBind(b: Bind): Bind = { + val b2 = super.transformBind(b) + val sv = scope.get(b2.name) + if (sv == null) return b2 + + if (b2.args != null) { + // Direct function binding: local sum(n, acc) = body + // Only auto-TCO if the function has at least one non-recursive exit path. + // This prevents turning trivially infinite recursions (e.g. `f(x) = f(x)`) into + // infinite trampoline loops; without a base case the TailCall trampoline would never + // produce a non-TailCall result. + if (!hasNonRecursiveExit(b2.rhs, b2.name, sv.idx, b2.args.names.length)) return b2 + val newRhs = markTailCalls(b2.rhs, b2.name, sv.idx, b2.args.names.length) + if (newRhs ne b2.rhs) b2.copy(rhs = newRhs) else b2 + } else + b2.rhs match { + case f: Function => + // Function literal binding: local sum = function(n, acc) body + if (!hasNonRecursiveExit(f.body, b2.name, sv.idx, f.params.names.length)) return b2 + val newBody = markTailCalls(f.body, b2.name, sv.idx, f.params.names.length) + if (newBody ne f.body) b2.copy(rhs = f.copy(body = newBody)) else b2 + case _ => b2 + } + } + + /** + * Check whether a function body has at least one code path that does NOT end in a self-recursive + * tail call. Functions with no non-recursive exit (e.g. `f(x) = f(x)`) should NOT be auto-TCO'd + * because the TailCall trampoline would loop forever — the function never produces a base-case + * result. + * + * The analysis mirrors `markTailCalls` traversal: it propagates through IfElse, LocalExpr, + * AssertExpr, And/Or (rhs only), and Expr.Error (value) — the same positions where tail calls are + * detected, returning `true` as soon as any leaf expression is found that is NOT a self-recursive + * call. + */ + private def hasNonRecursiveExit( + body: Expr, + selfName: String, + selfIdx: Int, + paramCount: Int + ): Boolean = { + def isSelfTailCall(e: Expr): Boolean = e match { + case a: Apply0 => + a.value match { + case ValidId(_, n, i) => n == selfName && i == selfIdx && paramCount == 0 + case _ => false + } + case a: Apply1 => + a.value match { + case ValidId(_, n, i) => n == selfName && i == selfIdx && paramCount == 1 + case _ => false + } + case a: Apply2 => + a.value match { + case ValidId(_, n, i) => n == selfName && i == selfIdx && paramCount == 2 + case _ => false + } + case a: Apply3 => + a.value match { + case ValidId(_, n, i) => n == selfName && i == selfIdx && paramCount == 3 + case _ => false + } + case a: Apply => + a.value match { + case ValidId(_, n, i) => + n == selfName && i == selfIdx && a.namedNames == null && a.args.length == paramCount + case _ => false + } + case _ => false + } + + body match { + case e: IfElse => + // Either branch being non-recursive is sufficient + hasNonRecursiveExit(e.`then`, selfName, selfIdx, paramCount) || + (e.`else` != null && hasNonRecursiveExit(e.`else`, selfName, selfIdx, paramCount)) || + e.`else` == null // missing else returns Val.Null when condition is false — a non-recursive exit + case e: LocalExpr => + hasNonRecursiveExit(e.returned, selfName, selfIdx, paramCount) + case e: AssertExpr => + hasNonRecursiveExit(e.returned, selfName, selfIdx, paramCount) + case e: And => + // rhs of && is in tail position (when lhs is true, rhs result is returned directly). + // lhs is NOT in tail position, but provides a control-flow exit path (returns false + // when lhs is false). We must still check rhs for non-recursive exits to prevent + // auto-TCO on functions with no base case (e.g. `f() = true && f()`). + hasNonRecursiveExit(e.rhs, selfName, selfIdx, paramCount) + case e: Or => + // rhs of || is in tail position (when lhs is false, rhs result is returned directly). + // lhs is NOT in tail position, but provides a control-flow exit path (returns true + // when lhs is true). We must still check rhs for non-recursive exits. + hasNonRecursiveExit(e.rhs, selfName, selfIdx, paramCount) + case _: Expr.Error => + // error value is the last thing evaluated before throwing → non-recursive exit + true + case e if isSelfTailCall(e) => + false // this path IS a self-recursive tail call → not a non-recursive exit + case _ => + true // any other expression (literal, non-self call, binary op, etc.) is a base case + } + } + + /** + * Walk an expression tree looking for self-recursive calls in tail position. A call is in tail + * position if it is the last expression evaluated before the function returns — i.e. its result + * becomes the function's return value without further transformation. + * + * Tail position propagates through: + * - Both branches of `if-else` + * - The `returned` expression of `local ... ; returned` + * - The `returned` expression of `assert ... ; returned` + * - The `rhs` of `lhs && rhs` (when lhs is true, rhs is the result) + * - The `rhs` of `lhs || rhs` (when lhs is false, rhs is the result) + * - The `value` of `error value` (last thing evaluated before throwing) + * + * This matches the evaluator's `visitExprWithTailCallSupport` propagation rules exactly. + * + * @param body + * the expression to scan (already transformed by the base optimizer) + * @param selfName + * the name of the function being defined + * @param selfIdx + * the ValScope index of the function binding + * @param paramCount + * the number of parameters the function declares; only calls with exactly this many positional + * args are marked (avoids forcing default arg expressions) + * @return + * the expression with matching tail calls marked `tailstrict = true`, or `body` unchanged if no + * matches found (reference equality preserved) + */ + private def markTailCalls(body: Expr, selfName: String, selfIdx: Int, paramCount: Int): Expr = { + def isSelfCall(value: Expr, callArity: Int): Boolean = value match { + case ValidId(_, name, idx) => name == selfName && idx == selfIdx && callArity == paramCount + case _ => false + } + + body match { + case e: IfElse => + val t = markTailCalls(e.`then`, selfName, selfIdx, paramCount) + val el = + if (e.`else` != null) markTailCalls(e.`else`, selfName, selfIdx, paramCount) else null + if ((t eq e.`then`) && (el eq e.`else`)) body + else IfElse(e.pos, e.cond, t, el) + + case e: LocalExpr => + val ret = markTailCalls(e.returned, selfName, selfIdx, paramCount) + if (ret eq e.returned) body + else LocalExpr(e.pos, e.bindings, ret) + + case e: AssertExpr => + val ret = markTailCalls(e.returned, selfName, selfIdx, paramCount) + if (ret eq e.returned) body + else AssertExpr(e.pos, e.asserted, ret) + + case e: And => + // rhs of && is in tail position: when lhs evaluates to true, rhs is returned directly + val rhs2 = markTailCalls(e.rhs, selfName, selfIdx, paramCount) + if (rhs2 eq e.rhs) body + else And(e.pos, e.lhs, rhs2) + + case e: Or => + // rhs of || is in tail position: when lhs evaluates to false, rhs is returned directly + val rhs2 = markTailCalls(e.rhs, selfName, selfIdx, paramCount) + if (rhs2 eq e.rhs) body + else Or(e.pos, e.lhs, rhs2) + + case e: Expr.Error => + // error value is in tail position (last thing evaluated before throwing) + val v = markTailCalls(e.value, selfName, selfIdx, paramCount) + if (v eq e.value) body + else Expr.Error(e.pos, v) + + // Self-recursive tail calls: mark as tailstrict + strict=false to enable the TailCall trampoline + // while preserving lazy argument semantics (auto-TCO uses TailstrictModeAutoTCO, not + // TailstrictModeEnabled, so arguments are not eagerly forced). + // Only match when call arity == param count (no named args, no default args involved). + case a: Apply0 if !a.tailstrict && isSelfCall(a.value, 0) => + Apply0(a.pos, a.value, tailstrict = true, strict = false) + case a: Apply1 if !a.tailstrict && isSelfCall(a.value, 1) => + Apply1(a.pos, a.value, a.a1, tailstrict = true, strict = false) + case a: Apply2 if !a.tailstrict && isSelfCall(a.value, 2) => + Apply2(a.pos, a.value, a.a1, a.a2, tailstrict = true, strict = false) + case a: Apply3 if !a.tailstrict && isSelfCall(a.value, 3) => + Apply3(a.pos, a.value, a.a1, a.a2, a.a3, tailstrict = true, strict = false) + case a: Apply + if !a.tailstrict && a.namedNames == null && isSelfCall(a.value, a.args.length) => + Apply(a.pos, a.value, a.args, null, tailstrict = true, strict = false) + + case _ => body + } + } } diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 5ea069ad..ac405584 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -1414,30 +1414,40 @@ object Val { * Using a sealed trait (rather than a plain Boolean) gives the JVM JIT better type-profile * information at `if` guards, and makes the two modes self-documenting at call sites. * - * - [[TailstrictModeEnabled]]: caller will handle TailCall via [[TailCall.resolve]]; sentinels - * may be returned without resolution. + * - [[TailstrictModeEnabled]]: explicit `tailstrict` — caller will handle TailCall via + * [[TailCall.resolve]]; sentinels may be returned without resolution. Arguments are eagerly + * evaluated per the Jsonnet spec. * - [[TailstrictModeDisabled]]: normal call; any TailCall must be resolved before returning. + * - [[TailstrictModeAutoTCO]]: auto-TCO — like Enabled (TailCall sentinels may be returned), but + * arguments are NOT forced, preserving Jsonnet's standard lazy evaluation semantics. */ sealed trait TailstrictMode case object TailstrictModeEnabled extends TailstrictMode case object TailstrictModeDisabled extends TailstrictMode +case object TailstrictModeAutoTCO extends TailstrictMode /** - * Sentinel value for tail call optimization of `tailstrict` calls. When a function body's tail - * position is a `tailstrict` call, the evaluator returns a [[TailCall]] instead of recursing into - * the callee. [[TailCall.resolve]] then re-invokes the target function iteratively, eliminating - * native stack growth. + * Sentinel value for tail call optimization of `tailstrict` and auto-TCO calls. When a function + * body's tail position is a `tailstrict` or auto-TCO call, the evaluator returns a [[TailCall]] + * instead of recursing into the callee. [[TailCall.resolve]] then re-invokes the target function + * iteratively, eliminating native stack growth. * * This is an internal protocol value and must never escape to user-visible code paths (e.g. * materialization, object field access). Every call site that may produce a TailCall must either - * pass `TailstrictModeEnabled` (so the caller resolves it) or guard the result with - * [[TailCall.resolve]]. + * pass `TailstrictModeEnabled` / `TailstrictModeAutoTCO` (so the caller resolves it) or guard the + * result with [[TailCall.resolve]]. + * + * @param strict + * when true, [[TailCall.resolve]] uses [[TailstrictModeEnabled]] (explicit `tailstrict` — eager + * argument forcing per the Jsonnet spec). When false, [[TailstrictModeAutoTCO]] is used so that + * arguments remain lazy (preserving Jsonnet semantics). */ final class TailCall( val func: Val.Func, val args: Array[Eval], val namedNames: Array[String], - val callSiteExpr: Expr) + val callSiteExpr: Expr, + val strict: Boolean = false) extends Val { private[sjsonnet] def valTag: Byte = -1 def pos: Position = callSiteExpr.pos @@ -1449,8 +1459,14 @@ object TailCall { /** * Iteratively resolve a [[TailCall]] chain (trampoline loop). If `current` is not a TailCall, it - * is returned immediately. Otherwise, each TailCall's target function is re-invoked with - * `TailstrictModeEnabled` until a non-TailCall result is produced. + * is returned immediately. Otherwise, each TailCall's target function is re-invoked until a + * non-TailCall result is produced. + * + * The mode used for re-invocation depends on [[TailCall.strict]]: + * - `strict = true` (explicit `tailstrict`): uses [[TailstrictModeEnabled]], which forces eager + * argument evaluation inside `func.apply`. + * - `strict = false` (auto-TCO): uses [[TailstrictModeAutoTCO]], which preserves lazy argument + * evaluation — arguments are only evaluated when the function body accesses them. * * Error frames preserve the original call-site expression name (e.g. "Apply2") so that TCO does * not alter user-visible stack traces. @@ -1458,10 +1474,11 @@ object TailCall { @tailrec def resolve(current: Val)(implicit ev: EvalScope): Val = current match { case tc: TailCall => - implicit val tailstrictMode: TailstrictMode = TailstrictModeEnabled + val mode: TailstrictMode = + if (tc.strict) TailstrictModeEnabled else TailstrictModeAutoTCO val next = try { - tc.func.apply(tc.args, tc.namedNames, tc.callSiteExpr.pos) + tc.func.apply(tc.args, tc.namedNames, tc.callSiteExpr.pos)(ev, mode) } catch { case e: Error => throw e.addFrame(tc.callSiteExpr.pos, tc.callSiteExpr) diff --git a/sjsonnet/test/resources/new_test_suite/auto_tco.jsonnet b/sjsonnet/test/resources/new_test_suite/auto_tco.jsonnet new file mode 100644 index 00000000..4110937a --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/auto_tco.jsonnet @@ -0,0 +1,58 @@ +// Test automatic tail-call optimization (auto-TCO). +// The StaticOptimizer detects self-recursive calls in tail position and +// marks them as tailstrict, enabling the TailCall trampoline without +// requiring the user to write 'tailstrict' at each recursive call site. +// +// sjsonnet's default maxStack is 500, so any recursion deeper than 500 +// would fail with "Max stack frames exceeded" without TCO. The depths +// used here (10000) clearly exceed that limit, proving auto-TCO is active. + +// Direct function binding form: local sum(n, acc) = ... +local sum(n, acc) = + if n == 0 then acc + else sum(n - 1, acc + n); + +// Function literal binding form: local counter = function(n, acc) ... +local counter = function(n, acc) + if n == 0 then acc + else counter(n - 1, acc + 1); + +// Non-tail call: multiplication wraps the recursive call, so auto-TCO +// must NOT mark it (it would change semantics). Shallow depth is fine. +local factorial(n) = + if n <= 1 then 1 + else n * factorial(n - 1); + +// Multiple tail positions through if-else chains +local collatz_steps(n, steps) = + if n == 1 then steps + else if n % 2 == 0 then collatz_steps(n / 2, steps + 1) + else collatz_steps(3 * n + 1, steps + 1); + +// Tail call through local binding +local count_down(n) = + local next = n - 1; + if n == 0 then "done" + else count_down(next); + +// Depth 10000 >> maxStack(500): proves auto-TCO trampoline is active +std.assertEqual(sum(10000, 0), 50005000) && +// Function-literal form also gets auto-TCO +std.assertEqual(counter(10000, 0), 10000) && +// Non-tail recursion at safe depth, correctness check +std.assertEqual(factorial(10), 3628800) && +// Multi-branch tail calls +std.assertEqual(collatz_steps(27, 0), 111) && +// Tail call through local binding +std.assertEqual(count_down(10000), "done") && + +// Lazy semantics regression test: auto-TCO must NOT eagerly evaluate arguments. +// Without this fix, error("boom") would be eagerly evaluated and crash. +// With correct auto-TCO (TailstrictModeAutoTCO), args stay lazy and error("boom") +// is never evaluated because the function returns 1 on the second call. +local lazy_check(x, y) = + if x then 1 + else lazy_check(true, error "boom"); +std.assertEqual(lazy_check(false, 42), 1) && + +true diff --git a/sjsonnet/test/resources/new_test_suite/auto_tco.jsonnet.golden b/sjsonnet/test/resources/new_test_suite/auto_tco.jsonnet.golden new file mode 100644 index 00000000..27ba77dd --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/auto_tco.jsonnet.golden @@ -0,0 +1 @@ +true diff --git a/sjsonnet/test/resources/new_test_suite/auto_tco_directional.jsonnet b/sjsonnet/test/resources/new_test_suite/auto_tco_directional.jsonnet new file mode 100644 index 00000000..f6b06908 --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/auto_tco_directional.jsonnet @@ -0,0 +1,480 @@ +// Directional tests for automatic tail-call optimization (auto-TCO). +// These tests verify that auto-TCO: +// 1. CORRECTLY detects and optimizes tail-recursive calls +// 2. Does NOT incorrectly optimize non-tail calls +// 3. Works across all tail position patterns +// 4. Preserves lazy evaluation semantics +// 5. Handles edge cases correctly +// 6. Interacts correctly with explicit `tailstrict` annotation +// 7. Works with object-returning functions and comprehensions +// +// Tests use depth 10000 to prove trampoline is active (maxStack = 500). + +// ============================================================ +// SECTION 1: Tail position patterns +// ============================================================ + +// 1.1 Direct function binding (local f(args) = body) +local sum_accumulator(n, acc) = + if n <= 0 then acc + else sum_accumulator(n - 1, acc + n); + +// 1.2 Function literal binding (local f = function(args) body) +local count_down = function(n, acc) + if n <= 0 then acc + else count_down(n - 1, acc + 1); + +// 1.3 Multiple recursive calls — ONLY tail calls should be TCO'd +local mixed_calls(n) = + if n <= 0 then 0 + else + local non_tail = mixed_calls(n - 1); // NOT tail — wrapped in + + non_tail + n; // tail call: n is the result + +// 1.4 If-else chain with tail calls in ALL branches +local multi_branch(n) = + if n <= 0 then "zero" + else if n == 1 then "one" + else if n == 2 then "two" + else multi_branch(n - 1); + +// 1.5 Tail call through LocalExpr +local through_local(n) = + local x = n - 1; + local y = x; + if n <= 0 then 0 + else through_local(y); + +// 1.6 Tail call through deeply nested locals +local deeply_nested_locals(n) = + local a = n - 1; + local b = a; + local c = b; + local d = c; + local e = d; + if n <= 0 then 0 + else deeply_nested_locals(e); + +// 1.7 Tail call through AssertExpr +local through_assert(n) = + assert n >= 0; + if n <= 0 then 0 + else through_assert(n - 1); + +// 1.8 Tail call through || (rhs is tail position) +local through_or(n) = + if n <= 0 then true + else (false || through_or(n - 1)); + +// 1.9 Tail call through && (rhs is tail position) +local through_and(n) = + if n <= 0 then true + else (true && through_and(n - 1)); + +// 1.10 Nested || with tail call at deepest level +local nested_or(n) = + if n <= 0 then true + else (false || (false || (false || nested_or(n - 1)))); + +// 1.11 Nested && with tail call at deepest level +local nested_and(n) = + if n <= 0 then true + else (true && (true && (true && nested_and(n - 1)))); + +// 1.12 Mixed || and && — tail call in && rhs within || rhs +local mixed_or_and(n) = + if n <= 0 then true + else (false || (true && mixed_or_and(n - 1))); + +// 1.13 Mixed && and || — tail call in || rhs within && rhs +local mixed_and_or(n) = + if n <= 0 then true + else (true && (false || mixed_and_or(n - 1))); + +// 1.14 Complex boolean expression with tail call +local complex_bool(n) = + if n <= 0 then true + else ((true || false) && (false || true) && complex_bool(n - 1)); + +// 1.15 Tail call through multiple if-else-assert combinations +local combined_control_flow(n) = + assert n >= -1; + local x = n; + if x < 0 then -1 + else + assert x >= 0; + if x == 0 then 0 + else combined_control_flow(x - 1); + +// 1.16 Deeply nested if-else with tail call at bottom +local deep_if_chain(n) = + if n <= 0 then 0 + else + if n == 1 then 1 + else + if n == 2 then 2 + else + if n == 3 then 3 + else deep_if_chain(n - 1); + +// 1.17 Tail call in else branch only (asymmetric if) +local asymmetric_if(n) = + if n <= 0 then 0 + else asymmetric_if(n - 1); + +// 1.18 Tail call through error expression value +local error_value(n) = + if n <= 0 then error "done at " + n + else error_value(n - 1); + +// ============================================================ +// SECTION 2: Arity coverage (Apply0–Apply5+) +// ============================================================ + +// 2.1 Apply1 (1 parameter, despite the name — there's no true 0-param tail recursion) +local loop_zero_params(n) = + if n <= 0 then "done" + else loop_zero_params(n - 1); + +// 2.2 Apply1 (1 parameter) +local loop_one_param(n) = + if n <= 0 then n + else loop_one_param(n - 1); + +// 2.3 Apply2 (2 parameters) +local loop_two_params(n, acc) = + if n <= 0 then acc + else loop_two_params(n - 1, acc + 1); + +// 2.4 Apply3 (3 parameters) +local loop_three_params(n, acc, mult) = + if n <= 0 then acc * mult + else loop_three_params(n - 1, acc + 1, mult); + +// 2.5 Apply4 (4 parameters, generic Apply) +local loop_four_params(a, b, c, d) = + if a <= 0 then b + c + d + else loop_four_params(a - 1, b + 1, c + 1, d + 1); + +// 2.6 Apply5 (5 parameters) +local loop_five_params(a, b, c, d, e) = + if a <= 0 then b + c + d + e + else loop_five_params(a - 1, b + 1, c + 1, d + 1, e + 1); + +// ============================================================ +// SECTION 3: Argument passing modes +// ============================================================ + +// 3.1 Named arguments — rebindApply resolves to positional, so auto-TCO still applies +local with_named_args(a, b=0) = + if a <= 0 then b + else with_named_args(a=a - 1, b=b + 1); + +// 3.2 Default argument involvement +local with_default(a, b=0) = + if a <= 0 then b + else with_default(a - 1, b + 1); + +// ============================================================ +// SECTION 4: Explicit tailstrict + auto-TCO interaction +// ============================================================ + +// 4.1 User writes `tailstrict` on a self-recursive call. +// Auto-TCO should NOT double-mark (guard: `!a.tailstrict`). +// This uses explicit tailstrict (eager args), NOT auto-TCO (lazy args). +local explicit_tailstrict_self(n) = + if n <= 0 then 0 + else explicit_tailstrict_self(n - 1) tailstrict; + +// 4.2 Explicit tailstrict forces eager eval of ALL args, including unused ones. +// Without eager forcing, `error "kaboom"` would not be evaluated. +local force_error_check(x, y) = x; + +// 4.3 Mixed: some branches explicit tailstrict, some auto-TCO. +// Even branches: explicit tailstrict (eager). Odd branches: auto-TCO (lazy). +local mixed_explicit(n) = + if n <= 0 then 0 + else if n % 2 == 0 then mixed_explicit(n - 1) tailstrict + else mixed_explicit(n - 1); + +// ============================================================ +// SECTION 5: Object-returning and container-returning tail recursion +// ============================================================ +// These use an accumulator pattern so the recursive call IS in tail position +// (direct return), not in a Bind.rhs. + +// 5.1 Deeply tail-recursive function building an object via accumulator. +// Uses depth=400 (below maxStack=500) because object field access in lazy +// thunk args adds per-iteration overhead. The numeric accumulator tests +// (SECTION 2, 9) prove trampoline at depth > 500. +local build_obj(n, acc) = + if n <= 0 then acc + else build_obj(n - 1, { result: "done", count: acc.count + 1 }); + +// 5.2 Deeply tail-recursive function building an array via accumulator. +local build_arr(n, acc) = + if n <= 0 then acc + else build_arr(n - 1, acc + [n]); + +// 5.3 Nested object construction with tail-recursive field computation. +local nested_obj_builder(n) = { + level1: { + level2: { + value: build_obj(n, { count: 0 }).count, + }, + sibling: std.length(build_arr(n, [])), + }, +}; + +// 5.4 Object method that is tail-recursive and returns an object. +// This is NOT auto-TCO'd (self.method), so use shallow depth. +local obj_builder = { + build(n):: if n <= 0 then { n: 0 } else { n: n, prev: self.build(n - 1) }, +}; + +// ============================================================ +// SECTION 6: Nested functions, comprehensions, and stdlib callbacks +// ============================================================ + +// 6.1 Function defined inside another function, tail-recursive. +// The inner function's self-call should be auto-TCO'd. +local outer_factory(n) = + local inner(acc, remaining) = + if remaining <= 0 then acc + else inner(acc + remaining, remaining - 1); + inner(0, n); + +// 6.2 Comprehension body calls a tail-recursive function. +local sum_to(n, acc) = + if n <= 0 then acc + else sum_to(n - 1, acc + n); + +local comp_with_tco = [sum_to(i, 0) for i in [1, 2, 3, 4, 5]]; + +// 6.3 std.map with a tail-recursive callback. +local map_with_tco = std.map(function(x) sum_to(x, 0), [10, 20, 30]); + +// 6.4 std.foldl where the accumulator function is tail-recursive. +local fold_with_tco = std.foldl( + function(acc, x) acc + sum_to(x, 0), + [1, 2, 3], + 0 +); + +// 6.5 Nested comprehension with tail-recursive function. +local nested_comp = [[sum_to(j, 0) for j in [1, 2]] for i in [1, 2]]; + +// ============================================================ +// SECTION 7: Negative tests — ensure non-self-recursion is NOT auto-TCO'd +// ============================================================ + +// 7.1 Non-tail call (multiplication wraps recursive call) +local factorial_non_tail(n) = + if n <= 1 then 1 + else n * factorial_non_tail(n - 1); + +// 7.2 Non-tail call (addition wraps BOTH recursive calls — tree recursion) +local fibonacci_non_tail(n) = + if n <= 1 then 1 + else fibonacci_non_tail(n - 1) + fibonacci_non_tail(n - 2); + +// 7.3 Mutual recursion via object — NOT auto-TCO'd (different function names) +local mutual = { + even(n):: if n == 0 then true else self.odd(n - 1), + odd(n):: if n == 0 then false else self.even(n - 1), +}; + +// 7.4 Two different functions calling each other — NOT auto-TCO'd +local cross_funcs = { + a(n):: if n <= 0 then 0 else self.b(n - 1), + b(n):: if n <= 0 then 1 else self.a(n - 1), +}; + +// 7.5 Function calling a DIFFERENT function in tail position. +// The callee is tail-recursive, but the caller's call is NOT self-recursion. +local helper_func(n) = + if n <= 0 then 0 else helper_func(n - 1); + +local caller_func(n) = + if n <= 0 then 0 else helper_func(n); + +// 7.6 Function shadowing: inner binding shadows outer name. +// Each `f` should be analyzed independently. +local shadowing_test = + local f(n) = if n <= 0 then "outer" else f(n - 1); + local inner = + local f(n) = if n <= 0 then "inner" else f(n - 1); + f(5); + { outer_result: f(3), inner_result: inner }; + +// 7.7 Function captured via alias — still self-recursion (same scope). +local aliased_recursion(n) = + local go = aliased_recursion; // alias to self + if n <= 0 then 0 else go(n - 1); + +// ============================================================ +// SECTION 8: Lazy semantics — auto-TCO must NOT force argument evaluation +// ============================================================ + +// 8.1 Unused dangerous argument should never be evaluated +local lazy_recursive(flag, dangerous) = + if flag then "safe" + else lazy_recursive(true, dangerous); + +// 8.2 Short-circuit with error through || +local lazy_or(flag, dangerous) = + if flag then true + else (true || lazy_or(true, dangerous)); + +// 8.3 Lazy semantics through && with auto-TCO +local lazy_and(n, dangerous) = + if n <= 0 then true + else (true && lazy_and(n - 1, dangerous)); + +// 8.4 Lazy semantics through || with auto-TCO +local lazy_or_deep(n, dangerous) = + if n <= 0 then true + else (false || lazy_or_deep(n - 1, dangerous)); + +// ============================================================ +// SECTION 9: Performance verification — depth > maxStack proves trampoline +// ============================================================ + +// 9.1 Depth 600 (just above maxStack=500) — must use trampoline +local perf_600(n) = + if n <= 0 then "ok" else perf_600(n - 1); + +// 9.2 Depth 5000 — well above maxStack +local perf_5000(n, acc) = + if n <= 0 then acc else perf_5000(n - 1, acc + 1); + +// 9.3 Large depth through And chain — verifies hasNonRecursiveExit +// correctly handles nested boolean operators +local perf_and_chain(n) = + if n <= 0 then true + else (true && (true && (true && perf_and_chain(n - 1)))); + +// ============================================================ +// SECTION 10: Edge cases +// ============================================================ + +// 10.1 Function that returns itself (not a tail call) +local returns_function() = + function(x) x + 1; + +// 10.2 Object method self-recursion — NOT auto-TCO'd (self.method) +local obj_method = { + sum(n, acc):: if n <= 0 then acc else self.sum(n - 1, acc + n), +}; + +// 10.3 Non-tail call: string concat wraps recursive call +local string_wrap_non_tail(n) = + if n <= 0 then "0" else string_wrap_non_tail(n - 1) + ""; + +// 10.4 Many local bindings before the tail call +local many_locals(n) = + local a = n - 1; + local b = a + 0; + local c = b + 0; + local d = c + 0; + local e = d + 0; + local f = e + 0; + local g = f + 0; + local h = g + 0; + local i = h + 0; + local j = i + 0; + if n <= 0 then 0 else many_locals(j); + +// 10.5 Collatz sequence — different recursive calls in different branches +local collatz(n, steps) = + if n == 1 then steps + else if n % 2 == 0 then collatz(n / 2, steps + 1) + else collatz(3 * n + 1, steps + 1); + +// ============================================================ +// Run all tests +// ============================================================ +local depth = 10000; +local shallow = 20; + +// SECTION 1: Tail position patterns +std.assertEqual(sum_accumulator(depth, 0), 50005000) && +std.assertEqual(count_down(depth, 0), depth) && +std.assertEqual(mixed_calls(200), 20100) && +std.assertEqual(multi_branch(depth), "two") && +std.assertEqual(through_local(depth), 0) && +std.assertEqual(deeply_nested_locals(depth), 0) && +std.assertEqual(through_assert(depth), 0) && +std.assertEqual(through_or(depth), true) && +std.assertEqual(through_and(depth), true) && +std.assertEqual(nested_or(depth), true) && +std.assertEqual(nested_and(depth), true) && +std.assertEqual(mixed_or_and(depth), true) && +std.assertEqual(mixed_and_or(depth), true) && +std.assertEqual(complex_bool(depth), true) && +std.assertEqual(combined_control_flow(depth), 0) && +std.assertEqual(deep_if_chain(depth), 3) && +std.assertEqual(asymmetric_if(depth), 0) && + +// SECTION 2: Arity coverage +std.assertEqual(loop_zero_params(depth), "done") && +std.assertEqual(loop_one_param(depth), 0) && +std.assertEqual(loop_two_params(depth, 0), depth) && +std.assertEqual(loop_three_params(depth, 0, 1), depth) && +std.assertEqual(loop_four_params(depth, 0, 0, 0), depth * 3) && +std.assertEqual(loop_five_params(depth, 0, 0, 0, 0), depth * 4) && + +// SECTION 3: Argument passing modes +std.assertEqual(with_named_args(depth), depth) && +std.assertEqual(with_default(depth, 0), depth) && + +// SECTION 4: Explicit tailstrict + auto-TCO interaction +std.assertEqual(explicit_tailstrict_self(depth), 0) && +std.assertEqual(force_error_check(42, error "kaboom"), 42) && +std.assertEqual(mixed_explicit(shallow), 0) && + +// SECTION 5: Object/container-returning tail recursion +std.assertEqual(build_obj(400, { count: 0 }).count, 400) && +std.assertEqual(std.length(build_arr(400, [])), 400) && +std.assertEqual(nested_obj_builder(400).level1.level2.value, 400) && +std.assertEqual(nested_obj_builder(400).level1.sibling, 400) && +std.assertEqual(obj_builder.build(5).n, 5) && + +// SECTION 6: Nested functions, comprehensions, stdlib +std.assertEqual(outer_factory(depth), depth * (depth + 1) / 2) && +std.assertEqual(comp_with_tco, [1, 3, 6, 10, 15]) && +std.assertEqual(map_with_tco, [55, 210, 465]) && +std.assertEqual(fold_with_tco, 10) && +std.assertEqual(nested_comp, [[1, 3], [1, 3]]) && + +// SECTION 7: Negative tests +std.assertEqual(factorial_non_tail(10), 3628800) && +std.assertEqual(fibonacci_non_tail(20), 10946) && +std.assertEqual(mutual.even(10), true) && +std.assertEqual(mutual.odd(10), false) && +std.assertEqual(cross_funcs.a(10), 0) && +std.assertEqual(cross_funcs.b(10), 1) && +std.assertEqual(caller_func(shallow), 0) && +std.assertEqual(shadowing_test.outer_result, "outer") && +std.assertEqual(shadowing_test.inner_result, "inner") && +std.assertEqual(aliased_recursion(shallow), 0) && + +// SECTION 8: Lazy semantics +std.assertEqual(lazy_recursive(false, error "should not eval"), "safe") && +std.assertEqual(lazy_or(false, error "should not eval"), true) && +std.assertEqual(lazy_and(depth, error "should not eval"), true) && +std.assertEqual(lazy_or_deep(depth, error "should not eval"), true) && + +// SECTION 9: Performance verification +std.assertEqual(perf_600(600), "ok") && +std.assertEqual(perf_5000(5000, 0), 5000) && +std.assertEqual(perf_and_chain(1000), true) && + +// SECTION 10: Edge cases +std.isFunction(returns_function()) && +std.assertEqual(obj_method.sum(100, 0), 5050) && +std.assertEqual(string_wrap_non_tail(5), "0") && +std.assertEqual(many_locals(depth), 0) && +std.assertEqual(collatz(27, 0), 111) && + +true diff --git a/sjsonnet/test/resources/new_test_suite/auto_tco_directional.jsonnet.golden b/sjsonnet/test/resources/new_test_suite/auto_tco_directional.jsonnet.golden new file mode 100644 index 00000000..27ba77dd --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/auto_tco_directional.jsonnet.golden @@ -0,0 +1 @@ +true diff --git a/sjsonnet/test/resources/new_test_suite/auto_tco_patterns.jsonnet b/sjsonnet/test/resources/new_test_suite/auto_tco_patterns.jsonnet new file mode 100644 index 00000000..fb0513b8 --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/auto_tco_patterns.jsonnet @@ -0,0 +1,123 @@ +// Comprehensive test for automatic tail-call optimization (auto-TCO). +// Tests ALL tail position patterns that the StaticOptimizer should detect. +// +// Depth 10000 >> maxStack(500): proves auto-TCO trampoline is active. + +// 1. Direct function binding: local sum(n, acc) = ... +local sum(n, acc) = + if n == 0 then acc + else sum(n - 1, acc + n); + +// 2. Function literal binding: local counter = function(n, acc) ... +local counter = function(n, acc) + if n == 0 then acc + else counter(n - 1, acc + 1); + +// 3. Non-tail call: multiplication wraps the recursive call — auto-TCO must NOT mark it +local factorial(n) = + if n <= 1 then 1 + else n * factorial(n - 1); + +// 4. Multiple tail positions through if-else chains +local collatz_steps(n, steps) = + if n == 1 then steps + else if n % 2 == 0 then collatz_steps(n / 2, steps + 1) + else collatz_steps(3 * n + 1, steps + 1); + +// 5. Tail call through local binding +local count_down(n) = + local next = n - 1; + if n == 0 then "done" + else count_down(next); + +// 6. Tail call through || operator (rhs of || is tail position) +// Note: || requires boolean rhs, so we use a boolean-returning function +local f_or(n) = if n <= 0 then true else (false || f_or(n - 1)); + +// 7. Tail call through && operator (rhs of && is tail position) +local f_and(n) = if n <= 0 then true else (true && f_and(n - 1)); + +// 8. Nested || and && tail positions +local f_nested_or(n) = if n <= 0 then true else (false || (false || f_nested_or(n - 1))); +local f_nested_and(n) = if n <= 0 then true else (true && (true && f_nested_and(n - 1))); + +// 9. Mixed || and && tail positions +local f_mixed(n) = if n <= 0 then true else (true || (false && f_mixed(n - 1))); + +// 10. Tail call through assert (returned expr is tail position) +local f_assert(n) = + assert n >= 0; + if n == 0 then 0 else f_assert(n - 1); + +// 11. Lazy semantics: auto-TCO must NOT eagerly evaluate arguments +local lazy_check(x, y) = + if x then 1 + else lazy_check(true, error "boom"); + +// 12. Mutual recursion test (shallow depth — NOT auto-TCO'd, only direct self-recursion is) +// Using an object to make both functions visible to each other +local even_odd = { + even(n):: if n == 0 then true else self.odd(n - 1), + odd(n):: if n == 0 then false else self.even(n - 1), +}; + +// 13. Deeply nested local expressions +local f_nested_local(n) = + local a = n - 1; + local b = a; + local c = b; + if n <= 0 then 0 else f_nested_local(c); + +// 14. Tail call in else branch of deeply nested if-else +local f_deep_if(n) = + if n <= 0 then 0 + else + if n == 1 then 1 + else + if n == 2 then 2 + else f_deep_if(n - 1); + +// 15. Function with 0 params (Apply0) +local f_zero() = 42; + +// 16. Function with 4+ params (generic Apply) +local f_four(a, b, c, d) = + if a <= 0 then d + else f_four(a - 1, b, c, d + 1); + +// 17. Edge case: &&/|| with base case in outer if, recursive call in &&/|| rhs. +// This tests that hasNonRecursiveExit correctly identifies the outer if's base case +// (not the &&/|| rhs which always recurses). +local f_and_with_outer_base(n) = + if n <= 0 then 0 // outer base case + else (true && f_and_with_outer_base(n - 1)); // rhs always recurses, but outer if provides exit + +local f_or_with_outer_base(n) = + if n <= 0 then 0 + else (false || f_or_with_outer_base(n - 1)); + +// Run tests +std.assertEqual(sum(10000, 0), 50005000) && +std.assertEqual(counter(10000, 0), 10000) && +std.assertEqual(factorial(10), 3628800) && +std.assertEqual(collatz_steps(27, 0), 111) && +std.assertEqual(count_down(10000), "done") && +std.assertEqual(f_or(10000), true) && +std.assertEqual(f_and(10000), true) && +std.assertEqual(f_nested_or(10000), true) && +std.assertEqual(f_nested_and(10000), true) && +std.assertEqual(f_mixed(10000), true) && +std.assertEqual(f_assert(10000), 0) && +std.assertEqual(lazy_check(false, 42), 1) && +// Mutual recursion (shallow depth — NOT auto-TCO'd) +std.assertEqual(even_odd.even(10), true) && +std.assertEqual(even_odd.odd(10), false) && +std.assertEqual(f_nested_local(10000), 0) && +std.assertEqual(f_deep_if(10000), 2) && +std.assertEqual(f_zero(), 42) && +std.assertEqual(f_four(10000, 0, 0, 0), 10000) && +// Edge case: &&/|| with outer base case — auto-TCO should work because outer if provides exit +std.assertEqual(f_and_with_outer_base(10000), 0) && +std.assertEqual(f_or_with_outer_base(10000), 0) && + +true diff --git a/sjsonnet/test/resources/new_test_suite/auto_tco_patterns.jsonnet.golden b/sjsonnet/test/resources/new_test_suite/auto_tco_patterns.jsonnet.golden new file mode 100644 index 00000000..27ba77dd --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/auto_tco_patterns.jsonnet.golden @@ -0,0 +1 @@ +true diff --git a/sjsonnet/test/resources/new_test_suite/error.auto_tco_bool_check.jsonnet b/sjsonnet/test/resources/new_test_suite/error.auto_tco_bool_check.jsonnet new file mode 100644 index 00000000..7bd2b1be --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/error.auto_tco_bool_check.jsonnet @@ -0,0 +1,5 @@ +// Verify that && type check is enforced inside function bodies (auto-TCO path). +// Without the fix, visitExprWithTailCallSupport would skip the rhs type check +// and silently return "hello" instead of erroring. +local f() = true && "hello"; +f() diff --git a/sjsonnet/test/resources/new_test_suite/error.auto_tco_bool_check.jsonnet.golden b/sjsonnet/test/resources/new_test_suite/error.auto_tco_bool_check.jsonnet.golden new file mode 100644 index 00000000..a73f6a61 --- /dev/null +++ b/sjsonnet/test/resources/new_test_suite/error.auto_tco_bool_check.jsonnet.golden @@ -0,0 +1,2 @@ +sjsonnet.Error: binary operator && does not operate on strings. + at [f].(error.auto_tco_bool_check.jsonnet:4:18)