Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions examples/jld2_checkpoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
Demo: checkpointing and resuming LBFGS optimization with JLD2.

Usage (from the repo root):
julia --project=. examples/jld2_checkpoint.jl

What it shows:
1. Run LBFGS with a `checkpoint` callback that saves state to a JLD2 file
every N iterations.
2. Interrupt early (via `shouldstop`) to simulate a crashed job.
3. Load the last checkpoint from disk and resume to full convergence.
4. Verify the resumed result matches a reference run.
"""

using OptimKit
using LinearAlgebra
using JLD2 # `import Pkg; Pkg.add("JLD2")` if not yet installed

# ---------------------------------------------------------------------------
# Problem: minimise f(x) = ½ (x-y)ᵀ A (x-y)
# ---------------------------------------------------------------------------
"""
    make_fg(A, y)

Return a closure `fg(x)` computing the quadratic objective f(x) = ½ (x-y)ᵀ A (x-y)
together with its gradient A (x-y), in the `(f, g)` convention expected by OptimKit.
"""
function make_fg(A, y)
    return function (x)
        Δ = x - y
        grad = A * Δ            # gradient of the quadratic at x
        return dot(Δ, grad) / 2, grad
    end
end

# Reproducible random problem
import Random; Random.seed!(42)  # fixed seed so repeated runs build the same problem
n = 50
y = randn(n)                     # target point: the unique minimiser of f
A = let B = randn(n, n); B'B + 5I end # positive-definite, well-conditioned
fg = make_fg(A, y)
x₀ = randn(n)                    # random starting point for the optimizer
alg = LBFGS(; gradtol=1e-12, verbosity=0)

# ---------------------------------------------------------------------------
# Reference: run to convergence (ground truth)
# ---------------------------------------------------------------------------
x_ref, f_ref, _, _, _ = optimize(fg, x₀, alg)
println("Reference: f* = $f_ref, ‖x*-y‖ = $(norm(x_ref - y))")

# ---------------------------------------------------------------------------
# Helper: build a checkpoint callback that saves to `filepath` every
# `save_every` completed iterations using JLD2.
# ---------------------------------------------------------------------------
"""
    make_jld2_checkpoint(filepath; save_every=1)

Return a checkpoint callback (suitable for `optimize`'s `checkpoint` keyword) that
saves the received `LBFGSState` to `filepath` with JLD2 every `save_every`
completed iterations.

The state is first written to a temporary sibling file and then renamed into
place, so that a crash in the middle of a write cannot leave `filepath` holding
a truncated, unreadable checkpoint — the whole point of this demo is surviving
crashes.

Throws an `ArgumentError` if `save_every < 1`.
"""
function make_jld2_checkpoint(filepath::String; save_every::Int=1)
    save_every >= 1 || throw(ArgumentError("save_every must be ≥ 1, got $save_every"))
    function checkpoint(state::LBFGSState)
        if mod(state.numiter, save_every) == 0
            # Atomic-replace pattern: write aside, then rename over the target.
            tmppath = filepath * ".tmp"
            jldsave(tmppath; state)
            mv(tmppath, filepath; force=true)
            # Uncomment the line below to see checkpoint progress:
            # println(" [checkpoint] saved iter $(state.numiter), f=$(state.f)")
        end
        return nothing
    end
    return checkpoint
end

# ---------------------------------------------------------------------------
# Phase 1: run for up to 10 iterations, saving a checkpoint after each one
# ---------------------------------------------------------------------------
checkpoint_file = tempname() * ".jld2"

checkpoint_cb = make_jld2_checkpoint(checkpoint_file; save_every=1)
# `shouldstop` receives (x, f, g, numfg, numiter, t); returning true aborts the
# run early — here it simulates a job that dies after 10 iterations.
stop_at_10 = (x, f, g, numfg, numiter, t) -> numiter >= 10

x_part, f_part, _, numfg_part, history_part =
optimize(fg, x₀, alg;
checkpoint = checkpoint_cb,
shouldstop = stop_at_10,
hasconverged = (x, f, g, ng) -> ng <= 1e-12)

# `history_part` has one row per recorded value (initial point + one per
# iteration), hence the `-1` when reporting the iteration count.
println("\nPhase 1 done: $(size(history_part,1)-1) iterations, f = $f_part")
println("Checkpoint file: $checkpoint_file ($(round(filesize(checkpoint_file)/1024, digits=1)) KB)")

# ---------------------------------------------------------------------------
# Phase 2: load checkpoint and resume to convergence
# ---------------------------------------------------------------------------
# The do-block form of `jldopen` guarantees the file handle is closed even if
# reading throws.
state_loaded = jldopen(checkpoint_file, "r") do file
file["state"]
end

println("\nLoaded checkpoint: numiter=$(state_loaded.numiter), numfg=$(state_loaded.numfg)")
println(" fhistory length = $(length(state_loaded.fhistory)) (should be numiter+1)")
println(" H length = $(length(state_loaded.H)) (LBFGS memory used)")

# Passing an `LBFGSState` instead of a plain starting point resumes the run:
# iteration/evaluation counters and histories continue from the checkpoint.
x_resumed, f_resumed, _, numfg_resumed, history_resumed =
optimize(fg, state_loaded, alg)

println("\nPhase 2 done: total $(size(history_resumed,1)-1) iterations, f = $f_resumed")
println(" numfg (total) = $numfg_resumed")

# ---------------------------------------------------------------------------
# Sanity checks
# ---------------------------------------------------------------------------
# NOTE: the original `@assert a ≈ b rtol=1e-8 "msg"` form is broken: `@assert`
# receives `a ≈ b`, `rtol=1e-8` and `"msg"` as three separate macro arguments,
# so `rtol=1e-8` becomes the *message* and the tolerance is silently ignored.
# `@assert` can also be compiled out at higher optimization levels, so these
# checks use explicit `isapprox` calls with an `error` on failure.
isapprox(x_resumed, x_ref; rtol=1e-8) ||
    error("resumed solution differs from reference")
isapprox(f_resumed, f_ref; rtol=1e-8) ||
    error("resumed f* differs from reference")
history_resumed[1:size(history_part, 1), :] ≈ history_part ||
    error("history mismatch")

println("\n✓ All checks passed — resumed result matches reference run.")

# Clean up temp file
rm(checkpoint_file)
1 change: 1 addition & 0 deletions src/OptimKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ const lbfgs = LBFGS()

export optimize, gd, cg, lbfgs, optimtest
export GradientDescent, ConjugateGradient, LBFGS
export LBFGSState
export FletcherReeves, HestenesStiefel, PolakRibiere, HagerZhang, DaiYuan
export HagerZhangLineSearch

Expand Down
134 changes: 124 additions & 10 deletions src/lbfgs.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
LBFGS(m::Int = 8;
LBFGS(m::Int = 8;
acceptfirst::Bool = true,
maxiter::Int=MAXITER[], # 1_000_000
gradtol::Real=GRADTOL[], # 1e-8
Expand Down Expand Up @@ -53,9 +53,63 @@ function LBFGS(m::Int=8;
return LBFGS(m, maxiter, gradtol, acceptfirst, verbosity, linesearch)
end

"""
    LBFGSState

Captures the complete state of an LBFGS optimization, enabling checkpointing and
warm-starting. Instances are produced by the `checkpoint` callback passed to
[`optimize`](@ref), and can be passed back as the starting point to resume optimization.

## Fields
- `x`: Current parameter values
- `f`: Current function value
- `g`: Current gradient
- `H`: Current LBFGS inverse Hessian approximation (`LBFGSInverseHessian`)
- `numfg`: Cumulative number of function/gradient evaluations so far
- `numiter`: Cumulative number of completed iterations
- `fhistory`: History of function values — the initial value followed by one entry
  per completed iteration, i.e. `numiter + 1` entries
- `normgradhistory`: History of gradient norms (same length as `fhistory`)

## Example

Periodic checkpointing using `Serialization` from the standard library:

```julia
using Serialization, OptimKit

checkpoint_fn = state -> serialize("checkpoint.jls", state)
x, f, g, numfg, history = optimize(fg, x0, LBFGS(); checkpoint=checkpoint_fn)

# resume from the last checkpoint
state = deserialize("checkpoint.jls")
x, f, g, numfg, history = optimize(fg, state, LBFGS())
```

!!! note
    The `LBFGSState` struct stores references to the arrays `x`, `g`, and the vectors
    inside `H`. When using GPU arrays or other non-standard backends, ensure your
    serialization method handles those array types correctly.

!!! note
    When resuming, the `shouldstop` and `hasconverged` callbacks receive the *cumulative*
    `numfg` and `numiter` values from the original run. Pass a custom `shouldstop` if you
    need a fixed number of *additional* iterations.
"""
struct LBFGSState{X,G,F<:Real,H}
    x::X
    f::F
    g::G
    H::H
    numfg::Int
    numiter::Int
    fhistory::Vector{F}
    normgradhistory::Vector{F}
end

function optimize(fg, x, alg::LBFGS;
precondition=_precondition,
(finalize!)=_finalize!,
checkpoint=nothing,
shouldstop=DefaultShouldStop(alg.maxiter),
hasconverged=DefaultHasConverged(alg.gradtol),
retract=_retract, inner=_inner, (transport!)=_transport!,
Expand All @@ -70,15 +124,66 @@ function optimize(fg, x, alg::LBFGS;
normgrad = sqrt(innergg)
fhistory = [f]
normgradhistory = [normgrad]
t = time() - t₀
_hasconverged = hasconverged(x, f, g, normgrad)
_shouldstop = shouldstop(x, f, g, numfg, numiter, t)

TangentType = typeof(g)
ScalarType = typeof(innergg)
m = alg.m
H = LBFGSInverseHessian(m, TangentType[], TangentType[], ScalarType[])

return _lbfgs_loop!(fg, x, f, g, H, numfg, numiter, normgrad, fhistory,
normgradhistory, t₀, alg,
precondition, finalize!, checkpoint,
shouldstop, hasconverged,
retract, inner, transport!, scale!, add!,
isometrictransport)
end

"""
    optimize(fg, state::LBFGSState, alg::LBFGS; kwargs...) -> x, f, g, numfg, history

Continue an LBFGS run from a checkpointed [`LBFGSState`](@ref). Accepts the same
keyword arguments as the standard `optimize` method. The counters (`numfg`,
`numiter`) and the value/gradient-norm histories pick up where the checkpoint
left off, so the returned `history` matrix spans the entire run, prior
iterations included.
"""
function optimize(fg, state::LBFGSState, alg::LBFGS;
                  precondition=_precondition,
                  (finalize!)=_finalize!,
                  checkpoint=nothing,
                  shouldstop=DefaultShouldStop(alg.maxiter),
                  hasconverged=DefaultHasConverged(alg.gradtol),
                  retract=_retract, inner=_inner, (transport!)=_transport!,
                  (scale!)=_scale!, (add!)=_add!,
                  isometrictransport=(transport! == _transport! && inner == _inner))
    tstart = time()
    # Unpack the checkpoint. The histories are (shallow-)copied and the inverse
    # Hessian approximation is deep-copied so that continuing the run never
    # mutates the caller's saved `state`.
    fhistory = copy(state.fhistory)
    normgradhistory = copy(state.normgradhistory)
    H = deepcopy(state.H)
    return _lbfgs_loop!(fg, state.x, state.f, state.g, H,
                        state.numfg, state.numiter, normgradhistory[end],
                        fhistory, normgradhistory, tstart, alg,
                        precondition, finalize!, checkpoint,
                        shouldstop, hasconverged,
                        retract, inner, transport!, scale!, add!,
                        isometrictransport)
end

function _lbfgs_loop!(fg, x, f, g, H, numfg, numiter, normgrad, fhistory, normgradhistory,
t₀, alg::LBFGS,
precondition, finalize!, checkpoint,
shouldstop, hasconverged,
retract, inner, transport!, scale!, add!, isometrictransport)
verbosity = alg.verbosity
t = time() - t₀
_hasconverged = hasconverged(x, f, g, normgrad)
_shouldstop = shouldstop(x, f, g, numfg, numiter, t)

verbosity >= 2 &&
@info @sprintf("LBFGS: initializing with f = %.12e, ‖∇f‖ = %.4e", f, normgrad)

Expand Down Expand Up @@ -122,13 +227,12 @@ function optimize(fg, x, alg::LBFGS;
_hasconverged = hasconverged(x, f, g, normgrad)
_shouldstop = shouldstop(x, f, g, numfg, numiter, t)

# check stopping criteria and print info
if _hasconverged || _shouldstop
break
# print iteration info if continuing (preserves original verbosity behavior)
if !(_hasconverged || _shouldstop)
verbosity >= 3 &&
@info @sprintf("LBFGS: iter %4d, Δt %s: f = %.12e, ‖∇f‖ = %.4e, α = %.2e, m = %d, nfg = %d",
numiter, format_time(Δt), f, normgrad, α, length(H), nfg)
end
verbosity >= 3 &&
@info @sprintf("LBFGS: iter %4d, Δt %s: f = %.12e, ‖∇f‖ = %.4e, α = %.2e, m = %d, nfg = %d",
numiter, format_time(Δt), f, normgrad, α, length(H), nfg)

# transport gprev, ηprev and vectors in Hessian approximation to x
gprev = transport!(gprev, xprev, ηprev, α, x)
Expand Down Expand Up @@ -189,6 +293,16 @@ function optimize(fg, x, alg::LBFGS;
ρ = innerss / innersy
push!(H, (scale!(s, 1 / norms), scale!(y, 1 / norms), ρ))
end

# checkpoint after H is updated; called every iteration including the last
if !isnothing(checkpoint)
checkpoint(LBFGSState(x, f, g, H, numfg, numiter, fhistory, normgradhistory))
end

# break after checkpoint so the final state is always captured
if _hasconverged || _shouldstop
break
end
end
if _hasconverged
verbosity >= 2 &&
Expand Down
54 changes: 54 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,60 @@ algorithms = (GradientDescent, ConjugateGradient, LBFGS)
@test f < 1e-12
end

@testset "LBFGS checkpoint and resume" begin
    n = 20
    y = randn(n)
    A = let B = randn(n, n); B' * B + I end
    fg = quadraticproblem(A, y)
    x₀ = randn(n)
    alg = LBFGS(; verbosity=0, gradtol=1e-12, maxiter=10_000_000)

    # Run to full convergence as ground truth
    x_full, f_full, g_full, numfg_full, history_full = optimize(fg, x₀, alg)

    # Run with early stopping after 5 iterations and collect checkpoints
    saved_states = LBFGSState[]
    checkpoint_fn = state -> push!(saved_states, state)
    stop_after_5 = (x, f, g, numfg, numiter, t) -> numiter >= 5
    converged_1e12 = (x, f, g, normgrad) -> normgrad <= 1e-12
    x_part, f_part, g_part, numfg_part, history_part =
        optimize(fg, x₀, alg; checkpoint=checkpoint_fn, shouldstop=stop_after_5,
                 hasconverged=converged_1e12)

    # Checkpoint is called once per completed iteration
    @test length(saved_states) == 5

    # Checkpoint state at iteration 5 matches optimize's returned state
    state5 = saved_states[end]
    @test state5.numiter == 5
    @test state5.x ≈ x_part
    @test state5.f ≈ f_part
    @test state5.numfg == numfg_part
    @test length(state5.fhistory) == 6 # initial + 5 iterations
    @test length(state5.normgradhistory) == 6

    # Resume from checkpoint and run to convergence; result must match full run
    x_resumed, f_resumed, g_resumed, numfg_resumed, history_resumed =
        optimize(fg, state5, alg)
    @test x_resumed ≈ x_full rtol = 1e-10
    @test f_resumed ≈ f_full rtol = 1e-10

    # Resumed history prepends the prior run's history. Resuming copies the
    # stored history verbatim, so the first rows are exactly equal (per review:
    # `==`, not `≈`).
    @test size(history_resumed, 1) == size(history_full, 1)
    @test history_resumed[1:6, :] == history_part

    # Resume with additional checkpoint continues counting from previous numiter
    extra_states = LBFGSState[]
    stop_after_3_more = (x, f, g, numfg, numiter, t) -> numiter >= state5.numiter + 3
    optimize(fg, state5, alg;
             checkpoint=state -> push!(extra_states, state),
             shouldstop=stop_after_3_more,
             hasconverged=converged_1e12)
    @test length(extra_states) == 3
    @test extra_states[1].numiter == 6
    @test extra_states[end].numiter == 8
end

@testset "Aqua" verbose = true begin
using Aqua
Aqua.test_all(OptimKit)
Expand Down
Loading