diff --git a/go.mod b/go.mod index 7d25faaf41..3a65dd8a53 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,8 @@ require ( require github.com/getsentry/sentry-go/otel v0.44.1 +require github.com/hashicorp/golang-lru/v2 v2.0.7 + require ( cel.dev/expr v0.25.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect @@ -298,7 +300,7 @@ require ( gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - k8s.io/apiextensions-apiserver v0.35.0 // indirect + k8s.io/apiextensions-apiserver v0.35.0 k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect modernc.org/libc v1.70.0 // indirect diff --git a/pkg/cache/validating_cache.go b/pkg/cache/validating_cache.go new file mode 100644 index 0000000000..eb6dbd1aa9 --- /dev/null +++ b/pkg/cache/validating_cache.go @@ -0,0 +1,173 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package cache provides a generic, capacity-bounded cache with singleflight +// deduplication and per-hit liveness validation. +package cache + +import ( + "errors" + "fmt" + + lru "github.com/hashicorp/golang-lru/v2" + "golang.org/x/sync/singleflight" +) + +// ErrExpired is returned by the check function passed to New to signal that a +// cached entry has definitively expired and should be evicted. +var ErrExpired = errors.New("cache entry expired") + +// ValidatingCache is a node-local write-through cache backed by a +// capacity-bounded LRU map, with singleflight-deduplicated restore on cache +// miss and lazy liveness validation on cache hit. +// +// Type parameter K is the key type (must be comparable). +// Type parameter V is the cached value type. 
+// +// The no-resurrection invariant (preventing a concurrent restore from +// overwriting a deletion) is enforced via ContainsOrAdd: if a concurrent +// writer stored a value between load() returning and the cache being updated, +// the prior writer's value wins and the just-loaded value is discarded via +// onEvict. +type ValidatingCache[K comparable, V any] struct { + lruCache *lru.Cache[K, V] + flight singleflight.Group + load func(key K) (V, error) + check func(key K, val V) error + // onEvict is kept here so we can call it when discarding a concurrently + // loaded value that lost the race to a prior writer. + onEvict func(key K, val V) +} + +// New creates a ValidatingCache with the given capacity and callbacks. +// +// capacity is the maximum number of entries; it must be >= 1. When the cache +// is full and a new entry must be stored, the least-recently-used entry is +// evicted first. Values less than 1 panic. +// +// load is called on a cache miss to restore the value; it must not be nil. +// check is called on every cache hit to confirm liveness. It receives both the +// key and the cached value so callers can inspect the value without a separate +// read. Returning ErrExpired evicts the entry; any other error is transient +// (cached value returned unchanged). It must not be nil. +// onEvict is called after any eviction (LRU or expiry); it may be nil. +func New[K comparable, V any]( + capacity int, + load func(K) (V, error), + check func(K, V) error, + onEvict func(K, V), +) *ValidatingCache[K, V] { + if capacity < 1 { + panic(fmt.Sprintf("cache.New: capacity must be >= 1, got %d", capacity)) + } + if load == nil { + panic("cache.New: load must not be nil") + } + if check == nil { + panic("cache.New: check must not be nil") + } + + c, err := lru.NewWithEvict(capacity, onEvict) + if err != nil { + // Only possible if size <= 0, which we have already ruled out above. 
+ panic(fmt.Sprintf("cache.New: lru.NewWithEvict: %v", err)) + } + + return &ValidatingCache[K, V]{ + lruCache: c, + load: load, + check: check, + onEvict: onEvict, + } +} + +// getHit validates a known-present cache entry and returns its value. +// If the entry has definitively expired it is evicted and (zero, false) is +// returned. Transient check errors leave the entry in place and return the +// cached value. +func (c *ValidatingCache[K, V]) getHit(key K, val V) (V, bool) { + if err := c.check(key, val); err != nil { + if errors.Is(err, ErrExpired) { + // Remove fires the eviction callback automatically. + c.lruCache.Remove(key) + var zero V + return zero, false + } + } + return val, true +} + +// Get returns the value for key, loading it on a cache miss. On a cache hit +// the entry's liveness is validated via the check function provided to New: +// ErrExpired evicts the entry and returns (zero, false); transient errors +// return the cached value unchanged. On a cache miss, load is called under a +// singleflight group so at most one restore runs concurrently per key. +func (c *ValidatingCache[K, V]) Get(key K) (V, bool) { + if val, ok := c.lruCache.Get(key); ok { + return c.getHit(key, val) + } + + // Cache miss: use singleflight to prevent concurrent restores for the same key. + type result struct{ v V } + raw, err, _ := c.flight.Do(fmt.Sprint(key), func() (any, error) { + // Re-check the cache: a concurrent singleflight group may have stored + // the value between our miss check above and acquiring this group. + if existing, ok := c.lruCache.Get(key); ok { + return result{v: existing}, nil + } + + v, loadErr := c.load(key) + if loadErr != nil { + return nil, loadErr + } + + // Guard against a concurrent Set or Remove that occurred while load() was + // running. ContainsOrAdd stores only if absent; if another writer got + // in first, their value wins and we discard ours via onEvict. 
+ ok, _ := c.lruCache.ContainsOrAdd(key, v) + if ok { + // Another writer stored a value first; discard our loaded value and + // return the winner's. ContainsOrAdd and Get are separate lock + // acquisitions, so the winner may itself have been evicted by LRU + // pressure between the two calls — fall back to our freshly loaded + // value in that case rather than returning a zero value. + winner, found := c.lruCache.Get(key) + if !found { + // Winner was evicted between ContainsOrAdd and Get; keep our + // freshly loaded value rather than returning a zero value. + return result{v: v}, nil + } + // Discard our loaded value in favour of the winner. + if c.onEvict != nil { + c.onEvict(key, v) + } + return result{v: winner}, nil + } + + return result{v: v}, nil + }) + if err != nil { + var zero V + return zero, false + } + r, ok := raw.(result) + return r.v, ok +} + +// Set stores value under key, moving the entry to the MRU position. If the +// cache is at capacity, the least-recently-used entry is evicted first and +// onEvict is called for it. +func (c *ValidatingCache[K, V]) Set(key K, value V) { + c.lruCache.Add(key, value) +} + +// Remove evicts the entry for key, calling onEvict if the key was present. +// It is a no-op if the key is not in the cache. +func (c *ValidatingCache[K, V]) Remove(key K) { + c.lruCache.Remove(key) +} + +// Len returns the number of entries currently in the cache. +func (c *ValidatingCache[K, V]) Len() int { + return c.lruCache.Len() +} diff --git a/pkg/vmcp/server/sessionmanager/cache_test.go b/pkg/cache/validating_cache_test.go similarity index 50% rename from pkg/vmcp/server/sessionmanager/cache_test.go rename to pkg/cache/validating_cache_test.go index 1123499b96..209d8769d3 100644 --- a/pkg/vmcp/server/sessionmanager/cache_test.go +++ b/pkg/cache/validating_cache_test.go @@ -1,10 +1,11 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
// SPDX-License-Identifier: Apache-2.0 -package sessionmanager +package cache import ( "errors" + "fmt" "sync" "sync/atomic" "testing" @@ -13,27 +14,41 @@ import ( "github.com/stretchr/testify/require" ) -// sentinel type used to test that non-V values stored via Store are -// invisible to Get without triggering a restore. -type testSentinel struct{} - -// newStringCache builds a RestorableCache[string, string] for tests. +// newStringCache builds a ValidatingCache[string, string] for tests. func newStringCache( load func(string) (string, error), - check func(string) error, + check func(string, string) error, evict func(string, string), -) *RestorableCache[string, string] { - return newRestorableCache(load, check, evict) +) *ValidatingCache[string, string] { + return New(1000, load, check, evict) } // alwaysAliveCheck returns a check function that always reports the entry as alive. -func alwaysAliveCheck(_ string) error { return nil } +func alwaysAliveCheck(_ string, _ string) error { return nil } + +// --------------------------------------------------------------------------- +// Construction invariants +// --------------------------------------------------------------------------- + +func TestValidatingCache_New_PanicsOnZeroCapacity(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + New(0, func(_ string) (string, error) { return "", nil }, alwaysAliveCheck, nil) + }) +} + +func TestValidatingCache_New_PanicsOnNegativeCapacity(t *testing.T) { + t.Parallel() + assert.Panics(t, func() { + New(-1, func(_ string) (string, error) { return "", nil }, alwaysAliveCheck, nil) + }) +} // --------------------------------------------------------------------------- // Cache miss / restore // --------------------------------------------------------------------------- -func TestRestorableCache_CacheMiss_CallsLoad(t *testing.T) { +func TestValidatingCache_CacheMiss_CallsLoad(t *testing.T) { t.Parallel() loaded := false @@ -52,7 +67,7 @@ func 
TestRestorableCache_CacheMiss_CallsLoad(t *testing.T) { assert.True(t, loaded) } -func TestRestorableCache_CacheMiss_StoresResult(t *testing.T) { +func TestValidatingCache_CacheMiss_StoresResult(t *testing.T) { t.Parallel() calls := 0 @@ -70,7 +85,7 @@ func TestRestorableCache_CacheMiss_StoresResult(t *testing.T) { assert.Equal(t, 1, calls, "load should be called only once after caching") } -func TestRestorableCache_CacheMiss_LoadError_ReturnsNotFound(t *testing.T) { +func TestValidatingCache_CacheMiss_LoadError_ReturnsNotFound(t *testing.T) { t.Parallel() loadErr := errors.New("not found") @@ -89,7 +104,7 @@ func TestRestorableCache_CacheMiss_LoadError_ReturnsNotFound(t *testing.T) { // Cache hit / liveness // --------------------------------------------------------------------------- -func TestRestorableCache_CacheHit_AliveCheck_ReturnsCached(t *testing.T) { +func TestValidatingCache_CacheHit_AliveCheck_ReturnsCached(t *testing.T) { t.Parallel() c := newStringCache( @@ -105,14 +120,14 @@ func TestRestorableCache_CacheHit_AliveCheck_ReturnsCached(t *testing.T) { assert.Equal(t, "loaded-k", v) } -func TestRestorableCache_CacheHit_Expired_EvictsAndCallsOnEvict(t *testing.T) { +func TestValidatingCache_CacheHit_Expired_EvictsAndCallsOnEvict(t *testing.T) { t.Parallel() evictedKey := "" evictedVal := "" c := newStringCache( func(_ string) (string, error) { return "v", nil }, - func(_ string) error { return ErrExpired }, + func(_ string, _ string) error { return ErrExpired }, func(key, val string) { evictedKey = key evictedVal = val @@ -127,7 +142,7 @@ func TestRestorableCache_CacheHit_Expired_EvictsAndCallsOnEvict(t *testing.T) { assert.Equal(t, "v", evictedVal) } -func TestRestorableCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { +func TestValidatingCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { t.Parallel() calls := 0 @@ -137,7 +152,7 @@ func TestRestorableCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { calls++ return "v", nil 
}, - func(_ string) error { + func(_ string, _ string) error { if expired { return ErrExpired } @@ -155,12 +170,12 @@ func TestRestorableCache_CacheHit_Expired_EntryRemovedFromCache(t *testing.T) { assert.Equal(t, 2, calls, "load should be called twice: initial + after eviction") } -func TestRestorableCache_CacheHit_TransientCheckError_ReturnsCached(t *testing.T) { +func TestValidatingCache_CacheHit_TransientCheckError_ReturnsCached(t *testing.T) { t.Parallel() c := newStringCache( func(_ string) (string, error) { return "v", nil }, - func(_ string) error { return errors.New("transient storage error") }, + func(_ string, _ string) error { return errors.New("transient storage error") }, nil, ) c.Get("k") //nolint:errcheck // prime the cache @@ -171,66 +186,54 @@ func TestRestorableCache_CacheHit_TransientCheckError_ReturnsCached(t *testing.T } // --------------------------------------------------------------------------- -// Sentinel / raw access +// Set // --------------------------------------------------------------------------- -func TestRestorableCache_Sentinel_GetReturnsNotFound(t *testing.T) { +func TestValidatingCache_Set_StoresValue(t *testing.T) { t.Parallel() - loadCalled := false - c := newRestorableCache( - func(_ string) (string, error) { - loadCalled = true - return "", errors.New("should not be called") - }, + c := newStringCache( + func(_ string) (string, error) { return "", errors.New("should not call load") }, alwaysAliveCheck, nil, ) - c.Store("k", testSentinel{}) + c.Set("k", "v") v, ok := c.Get("k") - assert.False(t, ok, "sentinel should not satisfy type assertion to V") - assert.Empty(t, v) - assert.False(t, loadCalled, "load should not be called when a sentinel is present") + require.True(t, ok) + assert.Equal(t, "v", v) } -func TestRestorableCache_Peek_ReturnsSentinel(t *testing.T) { +func TestValidatingCache_Set_UpdatesExisting(t *testing.T) { t.Parallel() - c := newRestorableCache( - func(string) (string, error) { return "", nil }, + c := 
newStringCache( + func(_ string) (string, error) { return "loaded", nil }, alwaysAliveCheck, nil, ) + c.Get("k") //nolint:errcheck // prime with "loaded" + c.Set("k", "updated") - c.Store("k", testSentinel{}) - - raw, ok := c.Peek("k") + v, ok := c.Get("k") require.True(t, ok) - _, isSentinel := raw.(testSentinel) - assert.True(t, isSentinel) + assert.Equal(t, "updated", v) } -// TestRestorableCache_Sentinel_StoredDuringLoad verifies that a sentinel stored -// concurrently during load() is respected: load() should not overwrite the -// sentinel, and the loaded value should be discarded via onEvict. -func TestRestorableCache_Sentinel_StoredDuringLoad(t *testing.T) { +// --------------------------------------------------------------------------- +// LRU capacity +// --------------------------------------------------------------------------- + +func TestValidatingCache_LRU_EvictsLeastRecentlyUsed(t *testing.T) { t.Parallel() var evictedKeys []string var mu sync.Mutex - sentinelReady := make(chan struct{}) - loadStarted := make(chan struct{}) - - c := newRestorableCache( - func(_ string) (string, error) { - // Signal that load has started, then wait for the sentinel to be stored. - close(loadStarted) - <-sentinelReady - return "loaded-value", nil - }, + // capacity=2: inserting a third entry evicts the LRU. + c := New(2, + func(key string) (string, error) { return "val-" + key, nil }, alwaysAliveCheck, func(key, _ string) { mu.Lock() @@ -239,158 +242,152 @@ func TestRestorableCache_Sentinel_StoredDuringLoad(t *testing.T) { }, ) - done := make(chan struct{}) - go func() { - defer close(done) - v, ok := c.Get("k") - // The sentinel should have blocked the store; Get returns not-found. - assert.False(t, ok) - assert.Empty(t, v) - }() - - // Wait until load() has started, then inject a sentinel before it stores. 
- <-loadStarted - c.Store("k", testSentinel{}) - close(sentinelReady) - <-done - - // The sentinel must still be in the cache (not overwritten by the loaded value). - raw, ok := c.Peek("k") - require.True(t, ok) - _, isSentinel := raw.(testSentinel) - assert.True(t, isSentinel, "sentinel must not be overwritten by the restore") + c.Get("a") //nolint:errcheck // a=MRU + c.Get("b") //nolint:errcheck // b=MRU, a=LRU + c.Get("c") //nolint:errcheck // c=MRU, b, a=LRU → evicts a - // onEvict must have been called for the discarded loaded value. mu.Lock() defer mu.Unlock() - assert.Equal(t, []string{"k"}, evictedKeys, "loaded value must be evicted when sentinel is present") + assert.Equal(t, []string{"a"}, evictedKeys, "LRU entry (a) should be evicted") + + // a is evicted; b and c remain. + _, bPresent := c.Get("b") + assert.True(t, bPresent) + _, cPresent := c.Get("c") + assert.True(t, cPresent) } -// TestRestorableCache_Sentinel_BlocksRestoreViaInitialHit verifies that a -// sentinel already present in the cache when Get is called causes load() to be -// skipped and Get to return not-found. This exercises the initial-hit branch -// (the outer c.m.Load check), which short-circuits before entering the -// singleflight group. -// -// The singleflight re-check branch (c.m.Load inside flight.Do) has structurally -// identical logic: if the stored value is not a V, errSentinelFound is returned -// and load is not called. That branch cannot be targeted deterministically from -// outside without code instrumentation, because the re-check runs in the same -// goroutine as the initial miss with no synchronisation point between them. -// The sentinel-stored-during-load path (TestRestorableCache_Sentinel_StoredDuringLoad) -// and the LoadOrStore guard cover the concurrent-store window that follows. 
-func TestRestorableCache_Sentinel_BlocksRestoreViaInitialHit(t *testing.T) { +func TestValidatingCache_LRU_GetRefreshesMRUPosition(t *testing.T) { t.Parallel() - loadCalled := false - c := newRestorableCache( - func(_ string) (string, error) { - loadCalled = true - return "loaded", nil - }, + var evictedKeys []string + var mu sync.Mutex + + c := New(2, + func(key string) (string, error) { return "val-" + key, nil }, alwaysAliveCheck, - nil, + func(key, _ string) { + mu.Lock() + evictedKeys = append(evictedKeys, key) + mu.Unlock() + }, ) - // Sentinel is present before Get is called: the initial c.m.Load hit path - // returns (zero, false) without entering the singleflight group. - c.Store("k", testSentinel{}) + c.Get("a") //nolint:errcheck // a loaded (MRU) + c.Get("b") //nolint:errcheck // b loaded (MRU), a=LRU + c.Get("a") //nolint:errcheck // a accessed → a becomes MRU, b=LRU + c.Get("c") //nolint:errcheck // c loaded → evicts b (LRU), not a - v, ok := c.Get("k") - assert.False(t, ok, "Get must return not-found when sentinel is present") - assert.Empty(t, v) - assert.False(t, loadCalled, "load must not be called when a sentinel is in the cache") + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"b"}, evictedKeys, "b should be evicted (LRU after a was re-accessed)") + + _, aPresent := c.Get("a") + assert.True(t, aPresent, "a should still be in cache") } -func TestRestorableCache_Peek_MissingKey_ReturnsFalse(t *testing.T) { +func TestValidatingCache_LRU_SetRefreshesMRUPosition(t *testing.T) { t.Parallel() - c := newStringCache( - func(string) (string, error) { return "", nil }, + var evictedKeys []string + var mu sync.Mutex + + c := New(2, + func(key string) (string, error) { return "val-" + key, nil }, alwaysAliveCheck, - nil, + func(key, _ string) { + mu.Lock() + evictedKeys = append(evictedKeys, key) + mu.Unlock() + }, ) - _, ok := c.Peek("absent") - assert.False(t, ok) -} + c.Get("a") //nolint:errcheck // a=MRU + c.Get("b") //nolint:errcheck // 
b=MRU, a=LRU + c.Set("a", "x") // Set refreshes a to MRU; b becomes LRU + c.Get("c") //nolint:errcheck // c loaded → evicts b -// --------------------------------------------------------------------------- -// CompareAndSwap -// --------------------------------------------------------------------------- + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"b"}, evictedKeys) +} -func TestRestorableCache_CompareAndSwap_Success(t *testing.T) { +func TestValidatingCache_LRU_CapacityOne(t *testing.T) { t.Parallel() - c := newStringCache( - func(_ string) (string, error) { return "v1", nil }, + var evictedKeys []string + var mu sync.Mutex + + c := New(1, + func(key string) (string, error) { return "val-" + key, nil }, alwaysAliveCheck, - nil, + func(key, _ string) { + mu.Lock() + evictedKeys = append(evictedKeys, key) + mu.Unlock() + }, ) - c.Get("k") //nolint:errcheck // prime with "v1" - swapped := c.CompareAndSwap("k", "v1", "v2") - require.True(t, swapped) + c.Get("a") //nolint:errcheck + c.Get("b") //nolint:errcheck // evicts a + c.Get("c") //nolint:errcheck // evicts b - raw, ok := c.Peek("k") - require.True(t, ok) - assert.Equal(t, "v2", raw) + mu.Lock() + defer mu.Unlock() + assert.Equal(t, []string{"a", "b"}, evictedKeys) } -func TestRestorableCache_CompareAndSwap_WrongOld_Fails(t *testing.T) { +func TestValidatingCache_LRU_LargeCapacityNoEviction(t *testing.T) { t.Parallel() - c := newStringCache( - func(_ string) (string, error) { return "v1", nil }, + const n = 100 + c := New(n+1, + func(key string) (string, error) { return "val-" + key, nil }, alwaysAliveCheck, - nil, + func(key, _ string) { + t.Errorf("unexpected eviction for key %s", key) + }, ) - c.Get("k") //nolint:errcheck - swapped := c.CompareAndSwap("k", "wrong", "v2") - assert.False(t, swapped) + for i := range n { + c.Get(fmt.Sprintf("k%d", i)) //nolint:errcheck + } + assert.Equal(t, n, c.Len(), "no entries should be evicted when under capacity") } -// 
--------------------------------------------------------------------------- -// Delete -// --------------------------------------------------------------------------- - -func TestRestorableCache_Delete_RemovesEntry(t *testing.T) { +func TestValidatingCache_LRU_Len(t *testing.T) { t.Parallel() - c := newStringCache( + c := New(5, func(_ string) (string, error) { return "v", nil }, alwaysAliveCheck, nil, ) - c.Get("k") //nolint:errcheck - - c.Delete("k") - _, ok := c.Peek("k") - assert.False(t, ok) + assert.Equal(t, 0, c.Len()) + c.Get("a") //nolint:errcheck + assert.Equal(t, 1, c.Len()) + c.Get("b") //nolint:errcheck + assert.Equal(t, 2, c.Len()) } // --------------------------------------------------------------------------- // Re-check inside singleflight (TOCTOU prevention) // --------------------------------------------------------------------------- -func TestRestorableCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) { +func TestValidatingCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) { t.Parallel() - // Simulate the TOCTOU window: a goroutine sees a cache miss, then the - // value is stored externally before it enters the singleflight group. - // The re-check inside the group should find the value and skip load. var loadCount atomic.Int32 // The load function is gated: it waits until we signal that an external - // Store has been applied, mimicking a value written by another goroutine + // Set has been applied, mimicking a value written by another goroutine // between the miss check and the singleflight group. 
storeApplied := make(chan struct{}) c := newStringCache( func(_ string) (string, error) { - <-storeApplied // wait until external Store is applied + <-storeApplied // wait until external Set is applied loadCount.Add(1) return "from-load", nil }, @@ -403,16 +400,14 @@ func TestRestorableCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) result string ok bool ) - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { result, ok = c.Get("k") - }() + }) - // Store the value externally to simulate a concurrent writer, then release + // Set the value externally to simulate a concurrent writer, then release // the load function. The re-check at the top of the singleflight function // fires first and finds "external-value", so load is never called. - c.Store("k", "external-value") + c.Set("k", "external-value") close(storeApplied) wg.Wait() @@ -421,11 +416,61 @@ func TestRestorableCache_Singleflight_ReCheckReturnsPreStoredValue(t *testing.T) assert.Equal(t, int32(0), loadCount.Load(), "re-check should short-circuit before load is called") } +// TestValidatingCache_Singleflight_EvictsLoserWhenLoadRacesWriter covers the +// path where load() runs to completion but loses the ContainsOrAdd race to a +// concurrent Set. The loaded-but-discarded value must be passed to onEvict so +// any resources it holds (e.g. connections) can be cleaned up. +func TestValidatingCache_Singleflight_EvictsLoserWhenLoadRacesWriter(t *testing.T) { + t.Parallel() + + // loadReached is closed when load() is about to return, giving us a hook to + // race a Set before ContainsOrAdd is called. + loadReached := make(chan struct{}) + // allowReturn lets the test control exactly when load() returns. 
+ allowReturn := make(chan struct{}) + + var evictedKey, evictedVal string + c := newStringCache( + func(_ string) (string, error) { + close(loadReached) // signal: load has run + <-allowReturn // wait until test injects the concurrent Set + return "from-load", nil + }, + alwaysAliveCheck, + func(key, val string) { + evictedKey = key + evictedVal = val + }, + ) + + var wg sync.WaitGroup + var gotVal string + var gotOk bool + wg.Go(func() { + gotVal, gotOk = c.Get("k") + }) + + // Wait until load() is running, then inject a concurrent Set so that + // ContainsOrAdd finds the key already present and discards the loaded value. + <-loadReached + c.Set("k", "from-set") + close(allowReturn) // let load() return "from-load" + wg.Wait() + + // The concurrent Set wins: caller receives the Set value. + require.True(t, gotOk) + assert.Equal(t, "from-set", gotVal, "concurrent Set value should win") + + // The loaded-but-discarded value must be passed to onEvict. + assert.Equal(t, "k", evictedKey, "onEvict must be called for the discarded loaded value") + assert.Equal(t, "from-load", evictedVal, "onEvict must receive the discarded loaded value") +} + // --------------------------------------------------------------------------- // Singleflight deduplication // --------------------------------------------------------------------------- -func TestRestorableCache_Singleflight_DeduplicatesConcurrentMisses(t *testing.T) { +func TestValidatingCache_Singleflight_DeduplicatesConcurrentMisses(t *testing.T) { t.Parallel() const goroutines = 10 diff --git a/pkg/transport/session/session_data_storage.go b/pkg/transport/session/session_data_storage.go index 40588ea5be..9093fcdf5d 100644 --- a/pkg/transport/session/session_data_storage.go +++ b/pkg/transport/session/session_data_storage.go @@ -25,6 +25,9 @@ import ( // - Create atomically creates metadata for id only if it does not already exist. // Use this in preference to Load+Upsert to avoid TOCTOU races. 
// - Upsert creates or overwrites the metadata for id, refreshing the TTL. +// - Update overwrites metadata only if the key already exists (SET XX semantics). +// Returns (true, nil) if updated, (false, nil) if the session was not found. +// Use this instead of Load+Upsert to avoid TOCTOU resurrection races. // - Load retrieves metadata and refreshes the TTL (sliding-window expiry). // Returns ErrSessionNotFound if the session does not exist. // - Delete removes the session. It is not an error if the session is absent. @@ -39,6 +42,13 @@ type DataStorage interface { // Upsert creates or updates session metadata with a sliding TTL. Upsert(ctx context.Context, id string, metadata map[string]string) error + // Update overwrites session metadata only if the session ID already exists + // (conditional write, equivalent to Redis SET XX). Returns (true, nil) if + // the entry was updated, (false, nil) if it was not found, or (false, err) + // on storage errors. Use this instead of Load+Upsert to prevent resurrections + // after a concurrent Delete. + Update(ctx context.Context, id string, metadata map[string]string) (bool, error) + // Load retrieves session metadata and refreshes its TTL. // Returns ErrSessionNotFound if the session does not exist. 
Load(ctx context.Context, id string) (map[string]string, error) @@ -65,8 +75,9 @@ func NewLocalSessionDataStorage(ttl time.Duration) (*LocalSessionDataStorage, er return nil, fmt.Errorf("ttl must be a positive duration") } s := &LocalSessionDataStorage{ - ttl: ttl, - stopCh: make(chan struct{}), + sessions: make(map[string]*localDataEntry), + ttl: ttl, + stopCh: make(chan struct{}), } go s.cleanupRoutine() return s, nil diff --git a/pkg/transport/session/session_data_storage_local.go b/pkg/transport/session/session_data_storage_local.go index abb02c9f7d..bc125a0480 100644 --- a/pkg/transport/session/session_data_storage_local.go +++ b/pkg/transport/session/session_data_storage_local.go @@ -30,12 +30,13 @@ func (e *localDataEntry) lastAccess() time.Time { } // LocalSessionDataStorage implements DataStorage using an in-memory -// sync.Map with TTL-based eviction. +// map with TTL-based eviction. // // Sessions are evicted if they have not been accessed within the configured TTL. // A background goroutine runs until Close is called. type LocalSessionDataStorage struct { - sessions sync.Map // map[string]*localDataEntry + sessions map[string]*localDataEntry // guarded by mu + mu sync.Mutex ttl time.Duration stopCh chan struct{} stopOnce sync.Once @@ -49,9 +50,9 @@ func (s *LocalSessionDataStorage) Upsert(_ context.Context, id string, metadata if metadata == nil { metadata = make(map[string]string) } - // Store a defensive copy so callers cannot mutate stored data. 
- copied := maps.Clone(metadata) - s.sessions.Store(id, newLocalDataEntry(copied)) + s.mu.Lock() + s.sessions[id] = newLocalDataEntry(maps.Clone(metadata)) + s.mu.Unlock() return nil } @@ -61,26 +62,20 @@ func (s *LocalSessionDataStorage) Load(_ context.Context, id string) (map[string if id == "" { return nil, fmt.Errorf("cannot load session data with empty ID") } - - val, ok := s.sessions.Load(id) - if !ok { - return nil, ErrSessionNotFound + s.mu.Lock() + entry, ok := s.sessions[id] + if ok { + entry.lastAccessNano.Store(time.Now().UnixNano()) } - entry, ok := val.(*localDataEntry) + s.mu.Unlock() if !ok { - return nil, fmt.Errorf("invalid entry type in local session data storage") + return nil, ErrSessionNotFound } - - // Refresh last-access in place. deleteExpired re-checks the timestamp - // immediately before calling CompareAndDelete, so this atomic store is - // sufficient to prevent eviction of an actively accessed entry. - entry.lastAccessNano.Store(time.Now().UnixNano()) - return maps.Clone(entry.metadata), nil } -// Create atomically creates session metadata only if the session ID -// does not already exist. Uses sync.Map.LoadOrStore for atomicity. +// Create creates session metadata only if the session ID does not already exist. +// Returns (true, nil) if created, (false, nil) if the key already existed. 
func (s *LocalSessionDataStorage) Create(_ context.Context, id string, metadata map[string]string) (bool, error) { if id == "" { return false, fmt.Errorf("cannot write session data with empty ID") @@ -88,9 +83,31 @@ func (s *LocalSessionDataStorage) Create(_ context.Context, id string, metadata if metadata == nil { metadata = make(map[string]string) } - copied := maps.Clone(metadata) - _, loaded := s.sessions.LoadOrStore(id, newLocalDataEntry(copied)) - return !loaded, nil + s.mu.Lock() + defer s.mu.Unlock() + if _, exists := s.sessions[id]; exists { + return false, nil + } + s.sessions[id] = newLocalDataEntry(maps.Clone(metadata)) + return true, nil +} + +// Update overwrites session metadata only if the session ID already exists. +// Returns (true, nil) if updated, (false, nil) if not found. +func (s *LocalSessionDataStorage) Update(_ context.Context, id string, metadata map[string]string) (bool, error) { + if id == "" { + return false, fmt.Errorf("cannot write session data with empty ID") + } + if metadata == nil { + metadata = make(map[string]string) + } + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.sessions[id]; !ok { + return false, nil + } + s.sessions[id] = newLocalDataEntry(maps.Clone(metadata)) + return true, nil } // Delete removes session metadata. Not an error if absent. @@ -98,17 +115,18 @@ func (s *LocalSessionDataStorage) Delete(_ context.Context, id string) error { if id == "" { return fmt.Errorf("cannot delete session data with empty ID") } - s.sessions.Delete(id) + s.mu.Lock() + delete(s.sessions, id) + s.mu.Unlock() return nil } // Close stops the background cleanup goroutine and clears all stored metadata. 
func (s *LocalSessionDataStorage) Close() error { s.stopOnce.Do(func() { close(s.stopCh) }) - s.sessions.Range(func(key, _ any) bool { - s.sessions.Delete(key) - return true - }) + s.mu.Lock() + s.sessions = make(map[string]*localDataEntry) + s.mu.Unlock() return nil } @@ -140,26 +158,11 @@ func (s *LocalSessionDataStorage) cleanupRoutine() { func (s *LocalSessionDataStorage) deleteExpired() { cutoff := time.Now().Add(-s.ttl) - var toDelete []struct { - id string - entry *localDataEntry - } - s.sessions.Range(func(key, val any) bool { - entry, ok := val.(*localDataEntry) - if ok && entry.lastAccess().Before(cutoff) { - id, ok := key.(string) - if ok { - toDelete = append(toDelete, struct { - id string - entry *localDataEntry - }{id, entry}) - } - } - return true - }) - for _, item := range toDelete { - if item.entry.lastAccess().Before(cutoff) { - s.sessions.CompareAndDelete(item.id, item.entry) + s.mu.Lock() + defer s.mu.Unlock() + for id, entry := range s.sessions { + if entry.lastAccess().Before(cutoff) { + delete(s.sessions, id) } } } diff --git a/pkg/transport/session/session_data_storage_redis.go b/pkg/transport/session/session_data_storage_redis.go index 916a82050e..02d04c2027 100644 --- a/pkg/transport/session/session_data_storage_redis.go +++ b/pkg/transport/session/session_data_storage_redis.go @@ -87,6 +87,38 @@ func (s *RedisSessionDataStorage) Load(ctx context.Context, id string) (map[stri return metadata, nil } +// Update overwrites session metadata only if the key already exists. +// Uses Redis SET XX (set-if-exists) to prevent resurrecting a session that +// was deleted by a concurrent Delete call (e.g. from another pod). +// Returns (true, nil) if updated, (false, nil) if the key was not found. 
+func (s *RedisSessionDataStorage) Update(ctx context.Context, id string, metadata map[string]string) (bool, error) { + if id == "" { + return false, fmt.Errorf("cannot write session data with empty ID") + } + if metadata == nil { + metadata = make(map[string]string) + } + data, err := json.Marshal(metadata) + if err != nil { + return false, fmt.Errorf("failed to serialize session metadata: %w", err) + } + // Mode "XX" means "only set if the key already exists". + res, err := s.client.SetArgs(ctx, s.key(id), data, redis.SetArgs{ + Mode: "XX", + TTL: s.ttl, + }).Result() + if err != nil { + // go-redis surfaces the "key does not exist" nil bulk reply as redis.Nil. + if errors.Is(err, redis.Nil) { + return false, nil + } + return false, fmt.Errorf("failed to conditionally update session metadata: %w", err) + } + // SetArgs with Mode "XX" returns "" when the key does not exist and "OK" + // when the write succeeded. + return res == "OK", nil +} + // Create atomically creates session metadata only if the key does not // already exist. Uses Redis SET NX (set-if-not-exists) to eliminate the // TOCTOU race between Load and Upsert in multi-pod deployments. 
diff --git a/pkg/transport/session/session_data_storage_test.go b/pkg/transport/session/session_data_storage_test.go index 9b1be8b346..63b41f9107 100644 --- a/pkg/transport/session/session_data_storage_test.go +++ b/pkg/transport/session/session_data_storage_test.go @@ -191,6 +191,74 @@ func runDataStorageTests(t *testing.T, newStorage func(t *testing.T) DataStorage err := s.Delete(ctx, "") assert.Error(t, err) }) + + t.Run("Update overwrites existing entry and returns true", func(t *testing.T) { + t.Parallel() + s := newStorage(t) + ctx := context.Background() + + require.NoError(t, s.Upsert(ctx, "sess-update", map[string]string{"v": "original"})) + + updated, err := s.Update(ctx, "sess-update", map[string]string{"v": "updated"}) + require.NoError(t, err) + assert.True(t, updated, "should return true when key exists") + + loaded, err := s.Load(ctx, "sess-update") + require.NoError(t, err) + assert.Equal(t, "updated", loaded["v"]) + }) + + t.Run("Update on missing key returns (false, nil) without creating it", func(t *testing.T) { + t.Parallel() + s := newStorage(t) + ctx := context.Background() + + updated, err := s.Update(ctx, "sess-absent", map[string]string{"v": "new"}) + require.NoError(t, err) + assert.False(t, updated, "should return false when key does not exist") + + // The key must not have been created. 
+ _, err = s.Load(ctx, "sess-absent") + assert.ErrorIs(t, err, ErrSessionNotFound, "Update must not create a missing key") + }) + + t.Run("Update after Delete returns (false, nil)", func(t *testing.T) { + t.Parallel() + s := newStorage(t) + ctx := context.Background() + + require.NoError(t, s.Upsert(ctx, "sess-deleted", map[string]string{"v": "1"})) + require.NoError(t, s.Delete(ctx, "sess-deleted")) + + updated, err := s.Update(ctx, "sess-deleted", map[string]string{"v": "2"}) + require.NoError(t, err) + assert.False(t, updated, "should return false after key was deleted") + }) + + t.Run("Update with empty ID returns error", func(t *testing.T) { + t.Parallel() + s := newStorage(t) + ctx := context.Background() + + _, err := s.Update(ctx, "", map[string]string{}) + assert.Error(t, err) + }) + + t.Run("Update nil metadata is treated as empty map", func(t *testing.T) { + t.Parallel() + s := newStorage(t) + ctx := context.Background() + + require.NoError(t, s.Upsert(ctx, "sess-update-nil", map[string]string{"v": "original"})) + + updated, err := s.Update(ctx, "sess-update-nil", nil) + require.NoError(t, err) + assert.True(t, updated) + + loaded, err := s.Load(ctx, "sess-update-nil") + require.NoError(t, err) + assert.NotNil(t, loaded) + }) } // --------------------------------------------------------------------------- @@ -254,9 +322,11 @@ func TestLocalSessionDataStorage(t *testing.T) { // simulating an entry that has been idle for that duration. 
func backdateLocalEntry(t *testing.T, s *LocalSessionDataStorage, id string, age time.Duration) { t.Helper() - val, ok := s.sessions.Load(id) + s.mu.Lock() + entry, ok := s.sessions[id] + s.mu.Unlock() require.True(t, ok, "entry %q not found for backdating", id) - val.(*localDataEntry).lastAccessNano.Store(time.Now().Add(-age).UnixNano()) + entry.lastAccessNano.Store(time.Now().Add(-age).UnixNano()) } // --------------------------------------------------------------------------- @@ -330,4 +400,24 @@ func TestRedisSessionDataStorage(t *testing.T) { require.NoError(t, err) assert.NotEmpty(t, val) }) + + t.Run("Update refreshes TTL via SET XX", func(t *testing.T) { + t.Parallel() + s, mr := newTestRedisDataStorage(t) + ctx := context.Background() + + require.NoError(t, s.Upsert(ctx, "ttl-update", map[string]string{"v": "1"})) + mr.FastForward(29 * time.Minute) + + updated, err := s.Update(ctx, "ttl-update", map[string]string{"v": "2"}) + require.NoError(t, err) + assert.True(t, updated) + + // Advance past the original TTL; Update should have reset the clock. + mr.FastForward(2 * time.Minute) + + loaded, err := s.Load(ctx, "ttl-update") + require.NoError(t, err, "session should still be alive after TTL reset by Update") + assert.Equal(t, "2", loaded["v"]) + }) } diff --git a/pkg/vmcp/server/sessionmanager/cache.go b/pkg/vmcp/server/sessionmanager/cache.go deleted file mode 100644 index 52ee95a4d5..0000000000 --- a/pkg/vmcp/server/sessionmanager/cache.go +++ /dev/null @@ -1,162 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -package sessionmanager - -import ( - "errors" - "fmt" - "sync" - - "golang.org/x/sync/singleflight" -) - -// ErrExpired is returned by the check function passed to newRestorableCache to -// signal that a cached entry has definitively expired and should be evicted. 
-var ErrExpired = errors.New("cache entry expired") - -// errSentinelFound is returned inside the singleflight load function when a -// non-V value (e.g. terminatedSentinel) is present in the map. Returning an -// error aborts the load and causes Get to return (zero, false), consistent -// with the behaviour of the initial-hit path that also returns (zero, false) -// for non-V values. -var errSentinelFound = errors.New("sentinel stored in cache") - -// RestorableCache is a node-local write-through cache backed by a sync.Map, -// with singleflight-deduplicated restore on cache miss and lazy liveness -// validation on cache hit. -// -// Type parameter K is the key type (must be comparable). -// Type parameter V is the cached value type. -// -// Values are stored internally as any, which allows callers to place sentinel -// markers alongside V entries (e.g. a tombstone during teardown). Get performs -// a type assertion to V and treats non-V entries as "not found". Peek and -// Store expose raw any access for sentinel use. -type RestorableCache[K comparable, V any] struct { - m sync.Map - flight singleflight.Group - - // load is called on a cache miss. Return (value, nil) on success. - // A successful result is stored in the cache before being returned. - load func(key K) (V, error) - - // check is called on every cache hit to confirm liveness. Returning nil - // means the entry is alive. Returning ErrExpired means it has definitively - // expired (the entry is evicted). Any other error is treated as a transient - // failure and the cached value is returned unchanged. - check func(key K) error - - // onEvict is called after a confirmed-expired entry has been removed. The - // evicted value is passed to allow resource cleanup (e.g. closing - // connections). May be nil. - onEvict func(key K, v V) -} - -// TODO: add an age-based sweep to bound the lifetime of entries that are -// never accessed again after their storage TTL expires. 
The sweep would range -// over m, compare each entry's insertion time against a caller-supplied maxAge, -// and call onEvict for entries that are too old — all without touching storage. -// Until then, entries for idle sessions leak backend connections until the -// process restarts or the session ID is queried again. - -func newRestorableCache[K comparable, V any]( - load func(K) (V, error), - check func(K) error, - onEvict func(K, V), -) *RestorableCache[K, V] { - return &RestorableCache[K, V]{ - load: load, - check: check, - onEvict: onEvict, - } -} - -// Get returns the cached V value for key. -// -// On a cache hit, check is run first: ErrExpired evicts the entry and returns -// (zero, false); transient errors return the cached value unchanged. Non-V -// values stored via Store (e.g. sentinels) return (zero, false) without -// triggering a restore. -// -// On a cache miss, load is called under a singleflight group so at most one -// restore runs concurrently per key. -func (c *RestorableCache[K, V]) Get(key K) (V, bool) { - if raw, ok := c.m.Load(key); ok { - v, isV := raw.(V) - if !isV { - var zero V - return zero, false - } - if err := c.check(key); err != nil { - if errors.Is(err, ErrExpired) { - c.m.Delete(key) - if c.onEvict != nil { - c.onEvict(key, v) - } - var zero V - return zero, false - } - // Transient error — keep the cached value. - } - return v, true - } - - // Cache miss: use singleflight to prevent concurrent restores for the same key. - type result struct{ v V } - raw, err, _ := c.flight.Do(fmt.Sprint(key), func() (any, error) { - // Re-check the cache: a concurrent singleflight group may have stored - // the value between our miss check above and acquiring this group. - if stored, ok := c.m.Load(key); ok { - if v, isV := stored.(V); isV { - return result{v: v}, nil - } - // Non-V sentinel present (e.g. terminatedSentinel). Treat as a - // hard stop: do not call load() and do not overwrite the sentinel. 
- return nil, errSentinelFound - } - v, loadErr := c.load(key) - if loadErr != nil { - return nil, loadErr - } - // Guard against a sentinel being stored between load() completing and - // this Store call (Terminate() running concurrently). LoadOrStore is - // atomic: if a sentinel got in, we discard the freshly loaded value - // via onEvict rather than silently overwriting the sentinel. - if _, loaded := c.m.LoadOrStore(key, v); loaded { - if c.onEvict != nil { - c.onEvict(key, v) - } - return nil, errSentinelFound - } - return result{v: v}, nil - }) - if err != nil { - var zero V - return zero, false - } - r, ok := raw.(result) - return r.v, ok -} - -// Store sets key to value. value may be any type, including sentinel markers. -func (c *RestorableCache[K, V]) Store(key K, value any) { - c.m.Store(key, value) -} - -// Delete removes key from the cache. -func (c *RestorableCache[K, V]) Delete(key K) { - c.m.Delete(key) -} - -// Peek returns the raw value stored under key without type assertion, liveness -// check, or restore. Used for sentinel inspection. -func (c *RestorableCache[K, V]) Peek(key K) (any, bool) { - return c.m.Load(key) -} - -// CompareAndSwap atomically replaces the value stored under key from old to -// new. Both old and new may be any type, including sentinels. -func (c *RestorableCache[K, V]) CompareAndSwap(key K, old, replacement any) bool { - return c.m.CompareAndSwap(key, old, replacement) -} diff --git a/pkg/vmcp/server/sessionmanager/factory.go b/pkg/vmcp/server/sessionmanager/factory.go index 71fcd0c841..91dc25d8e2 100644 --- a/pkg/vmcp/server/sessionmanager/factory.go +++ b/pkg/vmcp/server/sessionmanager/factory.go @@ -32,6 +32,11 @@ import ( const instrumentationName = "github.com/stacklok/toolhive/pkg/vmcp" +// defaultCacheCapacity is the fallback used when FactoryConfig.CacheCapacity is +// zero (the Go zero value). 
This ensures the cache is always bounded; omitting +// CacheCapacity from a config does not silently enable unbounded growth. +const defaultCacheCapacity = 1000 + // FactoryConfig holds the session factory construction parameters that the // session manager needs to build its decorating factory. It is separate from // server.Config to avoid a circular import between the server and sessionmanager @@ -62,6 +67,13 @@ type FactoryConfig struct { // If non-nil, the optimizer factory (whether derived from OptimizerConfig or // supplied via OptimizerFactory) and workflow executors are wrapped with telemetry. TelemetryProvider *telemetry.Provider + + // CacheCapacity is the maximum number of live MultiSession entries held in + // the node-local ValidatingCache. When the cache is full the least-recently-used + // session is evicted (its backend connections are closed via onEvict). A value of + // 0 uses defaultCacheCapacity (1000). Negative values are rejected by + // sessionmanager.New. + CacheCapacity int } // resolveOptimizer wires the optimizer factory from cfg, applying telemetry diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index f2394ad68d..b7000c9b3f 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -26,6 +26,7 @@ import ( mcpserver "github.com/mark3labs/mcp-go/server" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/cache" transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/conversion" @@ -43,13 +44,6 @@ const ( MetadataValTrue = "true" ) -// terminatedSentinel is stored in sessions when Terminate() begins tearing -// down a MultiSession. 
sessions.Get returns (nil, false) for sentinel entries -// (non-V values), and DecorateSession's CAS-based re-check will fail, -// preventing concurrent writers from resurrecting a storage record that -// Terminate() has already deleted. -type terminatedSentinel struct{} - // Manager bridges the domain session lifecycle (MultiSession / MultiSessionFactory) // to the mark3labs SDK's SessionIdManager interface. // @@ -74,6 +68,12 @@ type terminatedSentinel struct{} // sticky routing when session-affinity is desired. When Redis is used as the // session-storage backend the metadata is durable across pod restarts, and the // live MultiSession can be re-created via factory.RestoreSession() on a cache miss. +// +// TODO: Long-term, the cache and storage should be layered behind a single +// interface so the session manager does not need to coordinate between them. +// Reads would go through the cache (handling misses, singleflight, and liveness +// transparently); writes go to storage; caching is an implementation detail +// hidden from the caller. type Manager struct { storage transportsession.DataStorage factory vmcpsession.MultiSessionFactory @@ -84,7 +84,7 @@ type Manager struct { // (HTTP connections, routing tables). On a cache miss it restores the // session from stored metadata; on a cache hit it confirms liveness via // storage.Load, which also refreshes the Redis TTL. 
- sessions *RestorableCache[string, vmcpsession.MultiSession] + sessions *cache.ValidatingCache[string, vmcpsession.MultiSession] } // New creates a Manager backed by the given SessionDataStorage and backend @@ -102,6 +102,13 @@ func New( if cfg == nil || cfg.Base == nil { return nil, nil, fmt.Errorf("sessionmanager.New: FactoryConfig.Base (SessionFactory) is required") } + if cfg.CacheCapacity < 0 { + return nil, nil, fmt.Errorf("sessionmanager.New: CacheCapacity must be >= 0 (got %d)", cfg.CacheCapacity) + } + capacity := cfg.CacheCapacity + if capacity == 0 { + capacity = defaultCacheCapacity + } if len(cfg.WorkflowDefs) > 0 && cfg.ComposerFactory == nil { return nil, nil, fmt.Errorf("sessionmanager.New: ComposerFactory is required when WorkflowDefs are provided") } @@ -135,7 +142,8 @@ func New( backendReg: backendRegistry, } - sm.sessions = newRestorableCache( + sm.sessions = cache.New( + capacity, sm.loadSession, sm.checkSession, func(id string, sess vmcpsession.MultiSession) { @@ -143,7 +151,7 @@ func New( slog.Warn("session cache: error closing evicted session", "session_id", id, "error", closeErr) } - slog.Warn("session cache: evicted expired session from node-local cache", + slog.Warn("session cache: session evicted from node-local cache", "session_id", id) }, ) @@ -343,16 +351,29 @@ func (sm *Manager) CreateSession( // Persist the serialisable session metadata to the pluggable backend (e.g. // Redis) so that Validate() and TTL management work correctly. The live // MultiSession itself is cached in the node-local multiSessions map below. + // + // Use Update (SET XX) rather than Upsert to close the TOCTOU window between + // the second placeholder check above and this write. If Terminate deleted the + // key in that window, Update returns (false, nil) and we bail without + // resurrecting the deleted session. 
storeCtx, storeCancel := context.WithTimeout(ctx, createSessionStorageTimeout) defer storeCancel() - if err := sm.storage.Upsert(storeCtx, sessionID, sess.GetMetadata()); err != nil { + stored, err := sm.storage.Update(storeCtx, sessionID, sess.GetMetadata()) + if err != nil { _ = sess.Close() sm.cleanupFailedPlaceholder(sessionID, placeholder2) return nil, fmt.Errorf("Manager.CreateSession: failed to store session metadata: %w", err) } + if !stored { + _ = sess.Close() + return nil, fmt.Errorf( + "Manager.CreateSession: session %q was terminated between placeholder check and metadata store", + sessionID, + ) + } // Cache the live MultiSession so that GetMultiSession can retrieve it. - sm.sessions.Store(sessionID, sess) + sm.sessions.Set(sessionID, sess) slog.Debug("Manager: created multi-session", "session_id", sessionID, @@ -366,13 +387,21 @@ func (sm *Manager) CreateSession( // as a valid session), and prevents repeated Validate() calls from refreshing // the Redis TTL and keeping the placeholder alive indefinitely. // +// Uses Update (SET XX) so that a Terminate() that already deleted the key is +// not inadvertently resurrected as a terminated entry. +// // Cleanup is best-effort: errors are logged but not returned, since the caller // already has an error to report. func (sm *Manager) cleanupFailedPlaceholder(sessionID string, metadata map[string]string) { - metadata[MetadataKeyTerminated] = MetadataValTrue + // Copy before mutating so the caller's map is not modified. 
+ terminated := make(map[string]string, len(metadata)+1) + for k, v := range metadata { + terminated[k] = v + } + terminated[MetadataKeyTerminated] = MetadataValTrue cleanupCtx, cancel := context.WithTimeout(context.Background(), createSessionStorageTimeout) defer cancel() - if err := sm.storage.Upsert(cleanupCtx, sessionID, metadata); err != nil { + if _, err := sm.storage.Update(cleanupCtx, sessionID, terminated); err != nil { slog.Warn("Manager.CreateSession: failed to mark failed placeholder as terminated; it will linger until TTL expires", "session_id", sessionID, "error", err) } @@ -415,11 +444,10 @@ func (sm *Manager) Validate(sessionID string) (isTerminated bool, err error) { // where client termination during the Phase 1→Phase 2 window could resurrect // sessions with open backend connections: // -// - MultiSession (Phase 2): Close() releases backend connections, then the -// session is deleted from storage immediately. After deletion Validate() -// returns (false, error) — the same response as "never existed". This is -// intentional: a terminated MultiSession has no resources to preserve, so -// immediate removal is cleaner than marking and waiting for TTL. +// - MultiSession (Phase 2): the storage key is deleted. The node-local cache +// self-heals on the next Get: checkSession detects ErrSessionNotFound, +// evicts the entry, and onEvict closes backend connections. After deletion +// Validate() returns (false, error) — the same response as "never existed". // // - Placeholder (Phase 1): the session is marked terminated=true and left // for TTL cleanup. This prevents CreateSession() from opening backend @@ -438,46 +466,10 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { ctx, cancel := context.WithTimeout(context.Background(), terminateTimeout) defer cancel() - // Check the node-local cache first: a fully-formed MultiSession is stored - // here while this pod owns it. 
- if v, ok := sm.sessions.Peek(sessionID); ok { - // A terminatedSentinel means another goroutine is already tearing down - // this session. Do not fall through to the placeholder path — that would - // race with the concurrent Terminate's storage.Delete and potentially - // recreate the storage record after it was deleted. - if _, isSentinel := v.(terminatedSentinel); isSentinel { - slog.Debug("Manager.Terminate: concurrent termination in progress, skipping", - "session_id", sessionID) - return false, nil - } - if multiSess, ok := v.(vmcpsession.MultiSession); ok { - // Publish the tombstone before deleting from storage. Any concurrent - // GetMultiSession call will see the terminatedSentinel and return - // (nil, false), and DecorateSession's CAS-based re-check will fail, - // preventing both from recreating the storage record after we delete it. - sm.sessions.Store(sessionID, terminatedSentinel{}) - - if deleteErr := sm.storage.Delete(ctx, sessionID); deleteErr != nil { - // Rollback: restore the live session so the caller can retry. - sm.sessions.Store(sessionID, multiSess) - return false, fmt.Errorf("Manager.Terminate: failed to delete session from storage: %w", deleteErr) - } - - // Storage is clean; remove the sentinel and release backend connections. - sm.sessions.Delete(sessionID) - if closeErr := multiSess.Close(); closeErr != nil { - slog.Warn("Manager.Terminate: error closing multi-session backend connections", - "session_id", sessionID, "error", closeErr) - } - slog.Info("Manager.Terminate: session terminated", "session_id", sessionID) - return false, nil - } - } - - // No MultiSession in the local map — treat as a placeholder session. - // Load current metadata, mark as terminated, and store back. + // Load current metadata to determine session phase. metadata, loadErr := sm.storage.Load(ctx, sessionID) if errors.Is(loadErr, transportsession.ErrSessionNotFound) { + // Already gone (concurrent termination or TTL expiry). 
slog.Debug("Manager.Terminate: session not found (already expired?)", "session_id", sessionID) return false, nil } @@ -485,36 +477,39 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, fmt.Errorf("Manager.Terminate: failed to load session %q: %w", sessionID, loadErr) } - // Placeholder session (not yet upgraded to MultiSession). - // - // This handles the race condition where a client sends DELETE between - // Generate() (Phase 1) and CreateSession() (Phase 2). The two-phase - // pattern creates a window where the session exists as a placeholder: - // - // 1. Client sends initialize → Generate() creates placeholder - // 2. Client sends DELETE before OnRegisterSession hook fires - // 3. We mark the placeholder as terminated (don't delete it) - // 4. CreateSession() hook fires → sees terminated flag → fails fast - // - // Without this branch, CreateSession() would open backend HTTP connections - // for a session the client already terminated, silently resurrecting it. - // - // We mark (not delete) so Validate() can return isTerminated=true, which - // lets the SDK distinguish "actively terminated" from "never existed". - // TTL cleanup will remove the placeholder later. + if _, isFullSession := metadata[sessiontypes.MetadataKeyTokenHash]; isFullSession { + // Phase 2 (full MultiSession): delete from storage, then evict from the + // node-local cache so onEvict closes backend connections immediately rather + // than waiting for the next Get or an LRU eviction. + if deleteErr := sm.storage.Delete(ctx, sessionID); deleteErr != nil { + return false, fmt.Errorf("Manager.Terminate: failed to delete session from storage: %w", deleteErr) + } + sm.sessions.Remove(sessionID) + slog.Info("Manager.Terminate: session terminated", "session_id", sessionID) + return false, nil + } + + // Phase 1 (placeholder): mark terminated so CreateSession fast-fails and + // Validate returns isTerminated=true during the TTL window. 
+ // Use Update (SET XX) rather than Upsert so we never resurrect a key that + // was concurrently deleted or expired between the Load above and this write. + // (false, nil) means already gone — treat as success. metadata[MetadataKeyTerminated] = MetadataValTrue - if storeErr := sm.storage.Upsert(ctx, sessionID, metadata); storeErr != nil { + updated, storeErr := sm.storage.Update(ctx, sessionID, metadata) + if storeErr != nil { slog.Warn("Manager.Terminate: failed to persist terminated flag for placeholder; attempting delete fallback", "session_id", sessionID, "error", storeErr) - // Use a fresh context: if ctx expired (deadline exceeded), the same - // context would cause the fallback delete to fail immediately too. deleteCtx, deleteCancel := context.WithTimeout(context.Background(), terminateTimeout) - defer deleteCancel() if deleteErr := sm.storage.Delete(deleteCtx, sessionID); deleteErr != nil { + deleteCancel() return false, fmt.Errorf( "Manager.Terminate: failed to persist terminated flag and delete placeholder: storeErr=%v, deleteErr=%w", storeErr, deleteErr) } + deleteCancel() + } else if !updated { + // Session expired or was concurrently deleted between Load and Update — already gone. + slog.Debug("Manager.Terminate: placeholder already gone before terminated flag could be set", "session_id", sessionID) } slog.Info("Manager.Terminate: session terminated", "session_id", sessionID) @@ -527,13 +522,15 @@ func (sm *Manager) Terminate(sessionID string) (isNotAllowed bool, err error) { // cross-pod RestoreSession call does not attempt to reconnect to the expired // backend session. // -// After a successful storage update the session is evicted from the node-local -// cache; the next GetMultiSession call triggers RestoreSession with the updated -// metadata, discarding the stale in-memory copy. +// After a successful storage update, the cached entry is not immediately evicted. 
+// On the next GetMultiSession call, checkSession detects that the stored +// MetadataKeyBackendIDs differs from the cached session's value, evicts the stale +// entry via onEvict, and triggers RestoreSession with the updated metadata. +// On storage error, no eviction occurs and the caller retries on the next access. // // This is a best-effort operation. If the session is absent from storage (not // found or terminated) the call is a silent no-op. Storage errors are logged -// but not returned; on error the cache is not evicted. +// but not returned. func (sm *Manager) NotifyBackendExpired(sessionID, workloadID string) { loadCtx, loadCancel := context.WithTimeout(context.Background(), notifyBackendExpiredTimeout) defer loadCancel() @@ -584,50 +581,27 @@ func (sm *Manager) NotifyBackendExpired(sessionID, workloadID string) { } } -// updateMetadata writes a complete metadata snapshot to storage and evicts the -// session from the node-local cache so the next GetMultiSession call triggers a -// fresh RestoreSession with the updated state. -// -// Cross-pod TOCTOU: a re-check Load is performed immediately before the Upsert -// to detect cross-pod session termination (where another pod calls -// storage.Delete). If the key is absent at re-check time we bail without -// upserting. A residual race remains between the re-check and the Upsert (a -// concurrent pod could delete the key in that window), but the window is now -// microseconds rather than the full NotifyBackendExpired span. Closing the race -// entirely would require a conditional write primitive (e.g. Redis SET XX / -// UpsertIfPresent) added to the DataStorage interface. -// -// NOTE: concurrent calls for the same session are last-write-wins. We assume -// parallel metadata writers within a session do not occur; NotifyBackendExpired -// is the only post-creation writer and backend expiry events are serialised by -// the backend registry. 
This can be retrofitted with CAS semantics or a version -// counter if that assumption changes. +// updateMetadata writes a complete metadata snapshot to storage using a +// conditional Update (SET XX). If the key is absent at update time (concurrent +// Delete), the call is a no-op. The cache self-heals on the next GetMultiSession +// call: checkSession detects metadata drift, evicts the stale entry, and +// RestoreSession reloads with fresh state. func (sm *Manager) updateMetadata(sessionID string, metadata map[string]string) error { - // Same-pod guard: if Terminate() is already tearing down this session on - // this pod the sentinel is in the cache and storage is already deleted. - if raw, ok := sm.sessions.Peek(sessionID); ok { - if _, isSentinel := raw.(terminatedSentinel); isSentinel { - return nil - } - } - ctx, cancel := context.WithTimeout(context.Background(), notifyBackendExpiredTimeout) defer cancel() - // Cross-pod guard: re-check that the storage record still exists before - // upserting. If another pod terminated the session (deleting the key) after - // NotifyBackendExpired's initial Load, we must not recreate the record. - if _, err := sm.storage.Load(ctx, sessionID); err != nil { - if errors.Is(err, transportsession.ErrSessionNotFound) { - return nil // session was terminated elsewhere; nothing to update - } + // Update only succeeds if the key still exists. A concurrent Delete (same + // pod or cross-pod) returns (false, nil), and we bail without resurrecting. 
+ updated, err := sm.storage.Update(ctx, sessionID, metadata) + if err != nil { return err } - - if err := sm.storage.Upsert(ctx, sessionID, metadata); err != nil { - return err + if !updated { + return nil // session was terminated; nothing to update } - sm.sessions.Delete(sessionID) + // The cache self-heals lazily: on the next GetMultiSession, checkSession detects + // either the absent storage key or stale MetadataKeyBackendIDs and evicts the + // entry, triggering a fresh RestoreSession. return nil } @@ -664,32 +638,26 @@ func (sm *Manager) GetMultiSession(sessionID string) (vmcpsession.MultiSession, // replacing the old session and its backend connections. This ensures that a // backend-expiry update written by pod A propagates to pod B on the next // cache access rather than waiting for natural TTL expiry. -func (sm *Manager) checkSession(sessionID string) error { +func (sm *Manager) checkSession(sessionID string, sess vmcpsession.MultiSession) error { checkCtx, cancel := context.WithTimeout(context.Background(), restoreStorageTimeout) defer cancel() metadata, err := sm.storage.Load(checkCtx, sessionID) if errors.Is(err, transportsession.ErrSessionNotFound) { - return ErrExpired + return cache.ErrExpired } if err != nil { return err // transient storage error — keep cached } if metadata[MetadataKeyTerminated] == MetadataValTrue { - return ErrExpired - } - - // If the cached session has backend metadata and it differs from storage, - // evict to pick up the update. Only compare when the cached session - // explicitly carries MetadataKeyBackendIDs to avoid spurious evictions for - // sessions whose in-memory representation does not track backend IDs (e.g. - // test mocks that return an empty metadata map). 
- if raw, ok := sm.sessions.Peek(sessionID); ok { - if sess, ok := raw.(vmcpsession.MultiSession); ok { - if cachedIDs, present := sess.GetMetadata()[vmcpsession.MetadataKeyBackendIDs]; present { - if cachedIDs != metadata[vmcpsession.MetadataKeyBackendIDs] { - return ErrExpired - } - } + return cache.ErrExpired + } + + // Compare backend IDs to detect cross-pod metadata drift. + // Only compare when the cached session carries MetadataKeyBackendIDs to + // avoid spurious evictions for sessions that don't track backend IDs. + if cachedIDs, present := sess.GetMetadata()[vmcpsession.MetadataKeyBackendIDs]; present { + if cachedIDs != metadata[vmcpsession.MetadataKeyBackendIDs] { + return cache.ErrExpired } } @@ -747,14 +715,9 @@ func (sm *Manager) loadSession(sessionID string) (vmcpsession.MultiSession, erro // and stores the result back. Returns an error if the session is not found or // has not yet been upgraded from placeholder to MultiSession. // -// A re-check is performed immediately before storing to guard against a -// race with Terminate(): if the session is deleted between GetMultiSession and -// the store, the store would silently resurrect a terminated session. -// The re-check catches that window. A narrow TOCTOU gap remains between the -// re-check and the store, but its consequence is bounded: Terminate() already -// called Close() on the underlying MultiSession before deleting it, so any -// resurrected decorator wraps an already-closed session and will fail on first -// use rather than leaking backend connections. +// storage.Update is the concurrency guard. If it returns (false, nil), the +// session was deleted; the cache entry will be evicted on the next Get when +// checkSession detects ErrSessionNotFound. 
func (sm *Manager) DecorateSession(sessionID string, fn func(sessiontypes.MultiSession) sessiontypes.MultiSession) error { sess, ok := sm.GetMultiSession(sessionID) if !ok { @@ -767,24 +730,24 @@ func (sm *Manager) DecorateSession(sessionID string, fn func(sessiontypes.MultiS if decorated.ID() != sessionID { return fmt.Errorf("DecorateSession: decorator changed session ID from %q to %q", sessionID, decorated.ID()) } - // Atomically replace the original entry with the decorated one. - // If Terminate() has stored a terminatedSentinel between the first - // GetMultiSession call above and here, CompareAndSwap returns false and - // we bail out before touching storage — preventing resurrection of a - // terminated session's storage record. - if !sm.sessions.CompareAndSwap(sessionID, sess, decorated) { - return fmt.Errorf("DecorateSession: session %q was terminated or concurrently modified during decoration", sessionID) - } - // Persist updated metadata to storage. On failure, attempt to rollback - // the local-map entry so the caller can retry. If Terminate() has since - // replaced the decorated entry with a sentinel, the rollback CAS returns - // false and we leave the sentinel in place. + + // Persist metadata to storage first via conditional Update (SET XX). + // Only update the node-local cache after a successful write so that a + // storage error or a concurrent delete never leaves a decorated (but + // unpersisted) value in the cache where retries could stack decorations. 
decorateCtx, decorateCancel := context.WithTimeout(context.Background(), decorateTimeout) defer decorateCancel() - if err := sm.storage.Upsert(decorateCtx, sessionID, decorated.GetMetadata()); err != nil { - _ = sm.sessions.CompareAndSwap(sessionID, decorated, sess) + updated, err := sm.storage.Update(decorateCtx, sessionID, decorated.GetMetadata()) + if err != nil { return fmt.Errorf("DecorateSession: failed to store decorated session metadata: %w", err) } + if !updated { + // Session was deleted (by Terminate or TTL) between Get and Update. + // Evict the stale cache entry so onEvict closes backend connections. + sm.sessions.Remove(sessionID) + return fmt.Errorf("DecorateSession: session %q was deleted during decoration", sessionID) + } + sm.sessions.Set(sessionID, decorated) return nil } diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index 8728bf9ecc..4c5a8a25bd 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -16,6 +16,7 @@ import ( "go.uber.org/mock/gomock" "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/cache" transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -92,6 +93,9 @@ func (alwaysFailDataStorage) Load(_ context.Context, _ string) (map[string]strin func (alwaysFailDataStorage) Create(_ context.Context, _ string, _ map[string]string) (bool, error) { return false, errors.New("storage unavailable") } +func (alwaysFailDataStorage) Update(_ context.Context, _ string, _ map[string]string) (bool, error) { + return false, errors.New("storage unavailable") +} func (alwaysFailDataStorage) Delete(_ context.Context, _ string) error { return nil } func (alwaysFailDataStorage) Close() error { return nil } @@ -176,7 +180,7 @@ func newTestSessionManager( ) 
(*Manager, transportsession.DataStorage) { t.Helper() storage := newTestSessionDataStorage(t) - sm, cleanup, err := New(storage, &FactoryConfig{Base: factory}, registry) + sm, cleanup, err := New(storage, &FactoryConfig{Base: factory, CacheCapacity: 1000}, registry) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) return sm, storage @@ -216,7 +220,7 @@ func TestSessionManager_Generate(t *testing.T) { ctrl := gomock.NewController(t) sess := newMockSession(t, ctrl, "placeholder", nil) factory := newMockFactory(t, ctrl, sess) - sm, cleanup, err := New(alwaysFailDataStorage{}, &FactoryConfig{Base: factory}, newFakeRegistry()) + sm, cleanup, err := New(alwaysFailDataStorage{}, &FactoryConfig{Base: factory, CacheCapacity: 1000}, newFakeRegistry()) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) @@ -573,7 +577,8 @@ func TestSessionManager_Terminate(t *testing.T) { MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { createdSess = newMockSession(t, ctrl, id, tools) - // Close() will be called exactly once during Terminate + // Close() is called eagerly by onEvict when Terminate removes + // the entry from the node-local cache after storage.Delete. createdSess.EXPECT().Close().Return(nil).Times(1) return createdSess, nil }).Times(1) @@ -589,7 +594,7 @@ func TestSessionManager_Terminate(t *testing.T) { require.NoError(t, err) require.NotNil(t, createdSess) - // Terminate should close the backend connections. + // Terminate deletes from storage and removes from cache; onEvict fires Close(). isNotAllowed, err := sm.Terminate(sessionID) require.NoError(t, err) assert.False(t, isNotAllowed) @@ -605,7 +610,8 @@ func TestSessionManager_Terminate(t *testing.T) { MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { sess := newMockSession(t, ctrl, id, nil) - sess.EXPECT().Close().Return(nil).Times(1) + // Close is called by onEvict when Terminate removes the cache entry. + sess.EXPECT().Close().Return(nil).AnyTimes() return sess, nil }).Times(1) @@ -618,6 +624,12 @@ func TestSessionManager_Terminate(t *testing.T) { _, err := sm.CreateSession(context.Background(), sessionID) require.NoError(t, err) + // Seed MetadataKeyTokenHash into storage so Terminate recognises this + // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. + require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyTokenHash: "", + })) + // Session must exist before termination. _, loadErr := storage.Load(context.Background(), sessionID) assert.NoError(t, loadErr, "session should exist in storage before Terminate") @@ -673,7 +685,7 @@ func TestSessionManager_Terminate(t *testing.T) { failStoreAfter: 1, // fail after 1 successful call (Generate's Create) failDelete: false, } - sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory}, registry) + sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory, CacheCapacity: 1000}, registry) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) @@ -711,7 +723,7 @@ func TestSessionManager_Terminate(t *testing.T) { failStoreAfter: 1, // fail after 1 successful call (Generate's Create) failDelete: true, } - sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory}, registry) + sm, cleanup, err := New(failingStorage, &FactoryConfig{Base: factory, CacheCapacity: 1000}, registry) require.NoError(t, err) t.Cleanup(func() { _ = cleanup(context.Background()) }) @@ -1875,20 +1887,26 @@ func TestSessionManager_DecorateSession(t *testing.T) { return sess, nil }).Times(1) - sm, _ := newTestSessionManager(t, factory, 
newFakeRegistry()) + sm, storage := newTestSessionManager(t, factory, newFakeRegistry()) sessionID := sm.Generate() require.NotEmpty(t, sessionID) _, err := sm.CreateSession(context.Background(), sessionID) require.NoError(t, err) + // Seed MetadataKeyTokenHash into storage so Terminate recognises this + // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. + require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyTokenHash: "", + })) + err = sm.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { // Simulate concurrent Terminate() completing during decoration. _, _ = sm.Terminate(sessionID) return sess }) require.Error(t, err) - assert.Contains(t, err.Error(), "was terminated or concurrently modified during decoration") + assert.Contains(t, err.Error(), "was deleted during decoration") // The session must not be resurrected. _, ok := sm.GetMultiSession(sessionID) @@ -1916,13 +1934,21 @@ func TestSessionManager_CheckSession(t *testing.T) { return f } + makeEmptySess := func(t *testing.T) vmcpsession.MultiSession { + t.Helper() + ctrl := gomock.NewController(t) + m := sessionmocks.NewMockMultiSession(ctrl) + m.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() + return m + } + t.Run("alive session returns nil", func(t *testing.T) { t.Parallel() sm, storage := newTestSessionManager(t, makeFactory(t), newFakeRegistry()) sessionID := "alive-session" require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{})) - err := sm.checkSession(sessionID) + err := sm.checkSession(sessionID, makeEmptySess(t)) assert.NoError(t, err, "alive session must return nil") }) @@ -1930,8 +1956,8 @@ func TestSessionManager_CheckSession(t *testing.T) { t.Parallel() sm, _ := newTestSessionManager(t, makeFactory(t), newFakeRegistry()) - err := sm.checkSession("nonexistent-session") - assert.ErrorIs(t, err, ErrExpired, "deleted session must 
return ErrExpired") + err := sm.checkSession("nonexistent-session", makeEmptySess(t)) + assert.ErrorIs(t, err, cache.ErrExpired, "deleted session must return ErrExpired") }) t.Run("terminated session returns ErrExpired", func(t *testing.T) { @@ -1945,8 +1971,8 @@ func TestSessionManager_CheckSession(t *testing.T) { MetadataKeyTerminated: MetadataValTrue, })) - err := sm.checkSession(sessionID) - assert.ErrorIs(t, err, ErrExpired, "terminated session must return ErrExpired") + err := sm.checkSession(sessionID, makeEmptySess(t)) + assert.ErrorIs(t, err, cache.ErrExpired, "terminated session must return ErrExpired") }) t.Run("stale backend list triggers cross-pod eviction", func(t *testing.T) { @@ -1970,10 +1996,10 @@ func TestSessionManager_CheckSession(t *testing.T) { cached.EXPECT().GetMetadata().Return(map[string]string{ vmcpsession.MetadataKeyBackendIDs: "backend-a,backend-b", }).AnyTimes() - sm.sessions.Store(sessionID, cached) + sm.sessions.Set(sessionID, cached) - err := sm.checkSession(sessionID) - assert.ErrorIs(t, err, ErrExpired, + err := sm.checkSession(sessionID, cached) + assert.ErrorIs(t, err, cache.ErrExpired, "stale backend list must return ErrExpired to trigger cross-pod eviction") }) @@ -1991,9 +2017,9 @@ func TestSessionManager_CheckSession(t *testing.T) { cached.EXPECT().GetMetadata().Return(map[string]string{ vmcpsession.MetadataKeyBackendIDs: "backend-a", }).AnyTimes() - sm.sessions.Store(sessionID, cached) + sm.sessions.Set(sessionID, cached) - err := sm.checkSession(sessionID) + err := sm.checkSession(sessionID, cached) assert.NoError(t, err, "matching backend list must return nil") }) @@ -2011,9 +2037,9 @@ func TestSessionManager_CheckSession(t *testing.T) { ctrl := gomock.NewController(t) cached := sessionmocks.NewMockMultiSession(ctrl) cached.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() - sm.sessions.Store(sessionID, cached) + sm.sessions.Set(sessionID, cached) - err := sm.checkSession(sessionID) + err := 
sm.checkSession(sessionID, cached) assert.NoError(t, err, "absent MetadataKeyBackendIDs in cache must not cause eviction") }) } @@ -2166,6 +2192,12 @@ func TestNotifyBackendExpired(t *testing.T) { _, err := sm.CreateSession(t.Context(), sessionID) require.NoError(t, err) + // Seed MetadataKeyTokenHash into storage so Terminate recognises this + // as a Phase 2 (full MultiSession) and deletes rather than marks terminated. + require.NoError(t, storage.Upsert(context.Background(), sessionID, map[string]string{ + sessiontypes.MetadataKeyTokenHash: "", + })) + _, err = sm.Terminate(sessionID) require.NoError(t, err) @@ -2177,9 +2209,13 @@ func TestNotifyBackendExpired(t *testing.T) { "terminated session must not be resurrected by NotifyBackendExpired") }) - t.Run("concurrent termination: sentinel prevents resurrection after Load succeeds", func(t *testing.T) { + t.Run("same-pod termination: storage.Update returns false, no resurrection", func(t *testing.T) { t.Parallel() + // Verify that updateMetadata's storage.Update (SET XX) prevents + // resurrection even when Terminate runs concurrently on the same pod. + // We model Terminate completing (key deleted) before updateMetadata + // reaches its storage.Update call. ctrl := gomock.NewController(t) registry := newFakeRegistry() sess := newMockSession(t, ctrl, "s", nil) @@ -2196,22 +2232,17 @@ func TestNotifyBackendExpired(t *testing.T) { map[string]string{"workload-a": "sess-a"}, ) - // Simulate Terminate-in-progress: inject the terminatedSentinel directly - // into the node-local cache (as Terminate does before calling - // storage.Delete) while leaving storage intact. This models the TOCTOU - // window where NotifyBackendExpired's Load succeeded before Terminate's - // storage.Delete ran but our sentinel check runs while the sentinel is - // still present. - sm.sessions.Store(sessionID, terminatedSentinel{}) + // Simulate Terminate having completed its storage.Delete already. 
+ require.NoError(t, storage.Delete(context.Background(), sessionID)) - // NotifyBackendExpired must detect the terminatedSentinel and bail - // before Upsert, leaving the storage record unmodified. + // storage.Update (SET XX) in updateMetadata returns (false, nil) because + // the key no longer exists — NotifyBackendExpired must bail without + // recreating the record. sm.NotifyBackendExpired(sessionID, "workload-a") - got, loadErr := storage.Load(context.Background(), sessionID) - require.NoError(t, loadErr) - assert.Equal(t, "workload-a", got[vmcpsession.MetadataKeyBackendIDs], - "storage must not be modified when terminatedSentinel is present") + _, loadErr := storage.Load(context.Background(), sessionID) + assert.ErrorIs(t, loadErr, transportsession.ErrSessionNotFound, + "NotifyBackendExpired must not resurrect a session whose storage key was deleted by Terminate") }) t.Run("cross-pod termination: absent storage key is a no-op (no resurrection)", func(t *testing.T) { @@ -2247,7 +2278,7 @@ func TestNotifyBackendExpired(t *testing.T) { "NotifyBackendExpired must not resurrect a session terminated by another pod") }) - t.Run("evicts session from node-local cache on success", func(t *testing.T) { + t.Run("lazy eviction: session stays in cache immediately after NotifyBackendExpired", func(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -2261,9 +2292,8 @@ func TestNotifyBackendExpired(t *testing.T) { _, err := sm.CreateSession(t.Context(), sessionID) require.NoError(t, err) - // CreateSession must have populated the node-local cache. - _, cached := sm.sessions.Peek(sessionID) - require.True(t, cached, "session must be in node-local cache after CreateSession") + // Session must be in cache after CreateSession. 
+ assert.Equal(t, 1, sm.sessions.Len(), "session must be in node-local cache after CreateSession") seedBackendMetadata(t, storage, sessionID, []string{"workload-a"}, @@ -2272,11 +2302,10 @@ func TestNotifyBackendExpired(t *testing.T) { sm.NotifyBackendExpired(sessionID, "workload-a") - // The session must have been evicted so the next GetMultiSession call - // triggers RestoreSession with the updated (backend-free) metadata. - _, stillCached := sm.sessions.Peek(sessionID) - assert.False(t, stillCached, - "session must be evicted from node-local cache after NotifyBackendExpired") + // With lazy eviction, session is still in cache immediately after NotifyBackendExpired. + // checkSession detects drift on the next GetMultiSession call. + assert.Equal(t, 1, sm.sessions.Len(), + "session must still be in cache immediately after NotifyBackendExpired (eviction is lazy)") }) }