diff --git a/pkg/transport/proxy/transparent/backend_recovery_test.go b/pkg/transport/proxy/transparent/backend_recovery_test.go new file mode 100644 index 0000000000..dc51e9405e --- /dev/null +++ b/pkg/transport/proxy/transparent/backend_recovery_test.go @@ -0,0 +1,277 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package transparent + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/transport/session" +) + +// stubSessionStore is a minimal in-memory recoverySessionStore for unit tests. +type stubSessionStore struct { + sessions map[string]session.Session +} + +func newStubStore(sessions ...session.Session) *stubSessionStore { + m := make(map[string]session.Session) + for _, s := range sessions { + m[s.ID()] = s + } + return &stubSessionStore{sessions: m} +} + +func (s *stubSessionStore) Get(id string) (session.Session, bool) { + sess, ok := s.sessions[id] + return sess, ok +} + +func (s *stubSessionStore) UpsertSession(sess session.Session) error { + s.sessions[sess.ID()] = sess + return nil +} + +// newRecovery builds a backendRecovery backed by the given store and forward func. +func newRecovery(targetURL string, store recoverySessionStore, fwd func(*http.Request) (*http.Response, error)) *backendRecovery { + return &backendRecovery{ + targetURI: targetURL, + forward: fwd, + sessions: store, + } +} + +// TestBackendRecoveryNoSession verifies that reinitializeAndReplay returns +// (nil, nil) when the request carries no Mcp-Session-Id. 
+func TestBackendRecoveryNoSession(t *testing.T) { + t.Parallel() + + r := newRecovery("http://cluster-ip:8080", newStubStore(), nil) + req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + + resp, err := r.reinitializeAndReplay(req, nil) + assert.Nil(t, resp) + assert.NoError(t, err) +} + +// TestBackendRecoveryUnknownSession verifies that reinitializeAndReplay returns +// (nil, nil) when the session ID is not in the store. +func TestBackendRecoveryUnknownSession(t *testing.T) { + t.Parallel() + + r := newRecovery("http://cluster-ip:8080", newStubStore(), nil) + req, err := http.NewRequest(http.MethodPost, "http://cluster-ip:8080/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", uuid.New().String()) + + resp, err := r.reinitializeAndReplay(req, nil) + assert.Nil(t, resp) + assert.NoError(t, err) +} + +// TestBackendRecoveryNoInitBody verifies that when the session has no stored +// init body, reinitializeAndReplay resets backend_url to the ClusterIP and +// returns (nil, nil) so the caller falls through to a 404 the client can handle. +func TestBackendRecoveryNoInitBody(t *testing.T) { + t.Parallel() + + const clusterIP = "http://cluster-ip:8080" + clientSID := uuid.New().String() + sess := session.NewProxySession(clientSID) + sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080") // stale pod IP + store := newStubStore(sess) + + r := newRecovery(clusterIP, store, nil) + req, err := http.NewRequest(http.MethodPost, clusterIP+"/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", clientSID) + + resp, err := r.reinitializeAndReplay(req, nil) + assert.Nil(t, resp) + assert.NoError(t, err) + + // backend_url should be reset to ClusterIP so the next request routes correctly. 
+ updated, ok := store.Get(clientSID) + require.True(t, ok) + backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL) + require.True(t, exists) + assert.Equal(t, clusterIP, backendURL, "backend_url should be reset to ClusterIP when no init body") +} + +// TestBackendRecoveryHappyPath verifies the full re-init flow: the stored +// initialize body is replayed to the ClusterIP, the new backend session ID is +// captured, the session is updated, and the original request is replayed — all +// without standing up a full TransparentProxy. +func TestBackendRecoveryHappyPath(t *testing.T) { + t.Parallel() + + const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + newBackendSID := uuid.New().String() + var ( + forwardMu sync.Mutex + forwardCalls []string + ) + + // Backend: returns a session ID on initialize, 200 otherwise. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + forwardMu.Lock() + forwardCalls = append(forwardCalls, r.Header.Get("Mcp-Session-Id")) + forwardMu.Unlock() + if strings.Contains(string(body), `"initialize"`) { + w.Header().Set("Mcp-Session-Id", newBackendSID) + } + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + clientSID := uuid.New().String() + sess := session.NewProxySession(clientSID) + sess.SetMetadata(sessionMetadataInitBody, initBody) + store := newStubStore(sess) + + r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip) + + origBody := []byte(`{"method":"tools/list"}`) + req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp", + bytes.NewReader(origBody)) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", clientSID) + req.Header.Set("Content-Type", "application/json") + + resp, err := r.reinitializeAndReplay(req, origBody) + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() + + // Verify session was updated with new 
backend SID and a pod URL. + updated, ok := store.Get(clientSID) + require.True(t, ok) + backendSID, exists := updated.GetMetadataValue(sessionMetadataBackendSID) + require.True(t, exists) + assert.Equal(t, newBackendSID, backendSID) + + backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL) + require.True(t, exists) + assert.NotEmpty(t, backendURL) + + // Two forward calls: initialize + replay. The initialize must not carry + // a session ID; the replay must carry the new backend SID. + forwardMu.Lock() + defer forwardMu.Unlock() + require.Len(t, forwardCalls, 2, "forward should be called for initialize and replay") + assert.Empty(t, forwardCalls[0], "initialize request must not carry Mcp-Session-Id") + assert.Equal(t, newBackendSID, forwardCalls[1], "replay must carry the new backend SID") +} + +// TestBackendRecoveryReinitForwardError verifies that a forward error during +// re-initialization is returned to the caller. +func TestBackendRecoveryReinitForwardError(t *testing.T) { + t.Parallel() + + // Server that is immediately closed — all connections will be refused. 
+ dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + deadURL := dead.URL + dead.Close() + + clientSID := uuid.New().String() + sess := session.NewProxySession(clientSID) + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + store := newStubStore(sess) + + r := newRecovery(deadURL, store, http.DefaultTransport.RoundTrip) + + req, err := http.NewRequest(http.MethodPost, deadURL+"/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", clientSID) + + resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`)) + assert.Nil(t, resp) + assert.Error(t, err, "forward error during re-init should be returned") +} + +// TestBackendRecoveryNoNewSessionID verifies that when the re-initialize +// response carries no Mcp-Session-Id, reinitializeAndReplay resets backend_url +// to ClusterIP and returns (nil, nil). +func TestBackendRecoveryNoNewSessionID(t *testing.T) { + t.Parallel() + + // Backend that returns no Mcp-Session-Id on initialize. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) // no Mcp-Session-Id header + })) + defer backend.Close() + + clientSID := uuid.New().String() + sess := session.NewProxySession(clientSID) + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + sess.SetMetadata(sessionMetadataBackendURL, "http://10.0.0.5:8080") + store := newStubStore(sess) + + // targetURI points to backend (so the init request succeeds), but we verify + // that backend_url is reset to targetURI when no session ID comes back. 
+ r := newRecovery(backend.URL, store, http.DefaultTransport.RoundTrip) + + req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", clientSID) + + resp, err := r.reinitializeAndReplay(req, []byte(`{"method":"tools/list"}`)) + assert.Nil(t, resp) + assert.NoError(t, err) + + updated, ok := store.Get(clientSID) + require.True(t, ok) + backendURL, exists := updated.GetMetadataValue(sessionMetadataBackendURL) + require.True(t, exists) + assert.Equal(t, backend.URL, backendURL, "backend_url should fall back to targetURI when no new session ID") +} + +// TestPodBackendURLWithCapturedAddr verifies that a captured pod IP replaces the +// host in targetURI while preserving the scheme. +func TestPodBackendURLWithCapturedAddr(t *testing.T) { + t.Parallel() + + r := &backendRecovery{targetURI: "http://cluster-ip:8080"} + got := r.podBackendURL("10.0.0.5:8080") + assert.Equal(t, "http://10.0.0.5:8080", got) +} + +// TestPodBackendURLFallback verifies that an empty captured address falls back +// to targetURI unchanged. +func TestPodBackendURLFallback(t *testing.T) { + t.Parallel() + + r := &backendRecovery{targetURI: "http://cluster-ip:8080"} + got := r.podBackendURL("") + assert.Equal(t, "http://cluster-ip:8080", got) +} + +// TestPodBackendURLHTTPSFallback verifies that an HTTPS targetURI is never +// rewritten to a pod IP. IP-literal HTTPS URLs fail TLS verification because +// server certificates are issued for hostnames, not pod IPs. 
+func TestPodBackendURLHTTPSFallback(t *testing.T) { + t.Parallel() + + r := &backendRecovery{targetURI: "https://mcp.example.com/mcp"} + got := r.podBackendURL("1.2.3.4:443") + assert.Equal(t, "https://mcp.example.com/mcp", got, + "HTTPS target must not be rewritten to a pod IP") +} diff --git a/pkg/transport/proxy/transparent/backend_routing_test.go b/pkg/transport/proxy/transparent/backend_routing_test.go index 4167c86406..0d3319dc2c 100644 --- a/pkg/transport/proxy/transparent/backend_routing_test.go +++ b/pkg/transport/proxy/transparent/backend_routing_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "sync/atomic" "testing" "time" @@ -171,13 +172,13 @@ func TestRoundTripReturns404ForUnknownSession(t *testing.T) { })) defer backend.Close() - tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions( + tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions( "localhost", 0, backend.URL, nil, nil, nil, false, false, "sse", nil, nil, "", false, nil, - )} + )) req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp", strings.NewReader(`{"method":"tools/list"}`)) @@ -206,13 +207,13 @@ func TestRoundTripAllowsInitializeWithUnknownSession(t *testing.T) { })) defer backend.Close() - tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions( + tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions( "localhost", 0, backend.URL, nil, nil, nil, false, false, "sse", nil, nil, "", false, nil, - )} + )) req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp", strings.NewReader(`{"method":"initialize"}`)) @@ -238,13 +239,13 @@ func TestRoundTripAllowsBatchInitializeWithUnknownSession(t *testing.T) { })) defer backend.Close() - tt := &tracingTransport{base: http.DefaultTransport, p: NewTransparentProxyWithOptions( + tt := newTracingTransport(http.DefaultTransport, NewTransparentProxyWithOptions( "localhost", 0, backend.URL, 
nil, nil, nil, false, false, "sse", nil, nil, "", false, nil, - )} + )) req, err := http.NewRequest(http.MethodPost, backend.URL+"/mcp", strings.NewReader(`[{"method":"initialize"},{"method":"tools/list"}]`)) @@ -289,3 +290,287 @@ func TestRoundTripStoresBackendURLOnInitialize(t *testing.T) { require.True(t, ok, "session should have backend_url metadata") assert.Equal(t, backend.URL, backendURL) } + +// TestRoundTripStoresInitBodyOnInitialize verifies that the raw JSON-RPC initialize +// request body is stored in session metadata so the proxy can transparently +// re-initialize the backend session if the pod is later replaced. +func TestRoundTripStoresInitBodyOnInitialize(t *testing.T) { + t.Parallel() + + sessionID := uuid.New().String() + const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}` + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Mcp-Session-Id", sessionID) + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + proxy, addr := startProxy(t, backend.URL) + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + "http://"+addr+"/mcp", + strings.NewReader(initBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + + sess, ok := proxy.sessionManager.Get(normalizeSessionID(sessionID)) + require.True(t, ok, "session should have been created") + stored, exists := sess.GetMetadataValue(sessionMetadataInitBody) + require.True(t, exists, "init_body should be stored in session metadata") + assert.Equal(t, initBody, stored) +} + +// TestRoundTripReinitializesOnBackend404 verifies that when the backend pod returns +// 404 (session lost after restart on the same IP), the proxy transparently +// re-initializes the backend session and replays the original request — client sees 200. 
+func TestRoundTripReinitializesOnBackend404(t *testing.T) { + t.Parallel() + + // staleBackend simulates a pod that has lost its in-memory session state. + var staleHit atomic.Int32 + staleBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + staleHit.Add(1) + w.WriteHeader(http.StatusNotFound) + })) + defer staleBackend.Close() + + // freshBackend simulates a healthy pod: returns a session ID on initialize + // and 200 for all other requests. + freshSessionID := uuid.New().String() + var freshHit atomic.Int32 + freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + freshHit.Add(1) + body, _ := io.ReadAll(r.Body) + if strings.Contains(string(body), `"initialize"`) { + w.Header().Set("Mcp-Session-Id", freshSessionID) + } + w.WriteHeader(http.StatusOK) + })) + defer freshBackend.Close() + + // targetURI (ClusterIP) points to freshBackend; the session has staleBackend as backend_url. + proxy, addr := startProxy(t, freshBackend.URL) + + clientSessionID := uuid.New().String() + sess := session.NewProxySession(clientSessionID) + sess.SetMetadata(sessionMetadataBackendURL, staleBackend.URL) + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + require.NoError(t, proxy.sessionManager.AddSession(sess)) + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + "http://"+addr+"/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", clientSessionID) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init") + assert.GreaterOrEqual(t, staleHit.Load(), int32(1), "stale backend should have been hit") + assert.GreaterOrEqual(t, freshHit.Load(), int32(2), "fresh 
backend should receive initialize + replay") + + // Session should now have backend_sid mapping to the new backend session. + updated, ok := proxy.sessionManager.Get(normalizeSessionID(clientSessionID)) + require.True(t, ok, "session should still exist after re-init") + backendSID, exists := updated.GetMetadataValue(sessionMetadataBackendSID) + require.True(t, exists, "backend_sid should be set after re-init") + assert.Equal(t, freshSessionID, backendSID, "backend_sid must be the raw value the backend issued, not normalized") +} + +// TestRoundTripReinitializesPreservesNonUUIDBackendSessionID verifies that when the +// backend issues a non-UUID Mcp-Session-Id on re-initialization, the proxy stores +// and forwards the raw value — not a UUID v5 hash of it — on all subsequent requests. +// +// The normalization bug only manifests on the request AFTER the replay: the replay +// sets Mcp-Session-Id directly from newBackendSID (bypassing Rewrite), but subsequent +// requests go through the Rewrite closure which reads backend_sid from session metadata. +// If backend_sid was stored as normalizeSessionID(newBackendSID), Rewrite would send +// the wrong (hashed) value and the backend would reject every subsequent request. +func TestRoundTripReinitializesPreservesNonUUIDBackendSessionID(t *testing.T) { + t.Parallel() + + // Non-UUID opaque token, as some MCP servers issue. + const nonUUIDSessionID = "opaque-session-token-abc123" + + staleBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer staleBackend.Close() + + // receivedSIDs tracks Mcp-Session-Id values arriving on non-initialize requests, + // in order. Index 0 = replay (direct from reinitializeAndReplay), index 1 = second + // client request (routed through Rewrite reading backend_sid from session metadata). 
+ var ( + receivedMu sync.Mutex + receivedSIDs []string + ) + freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + if strings.Contains(string(body), `"initialize"`) { + w.Header().Set("Mcp-Session-Id", nonUUIDSessionID) + w.WriteHeader(http.StatusOK) + return + } + receivedMu.Lock() + receivedSIDs = append(receivedSIDs, r.Header.Get("Mcp-Session-Id")) + receivedMu.Unlock() + w.WriteHeader(http.StatusOK) + })) + defer freshBackend.Close() + + proxy, addr := startProxy(t, freshBackend.URL) + + clientSessionID := uuid.New().String() + sess := session.NewProxySession(clientSessionID) + sess.SetMetadata(sessionMetadataBackendURL, staleBackend.URL) + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + require.NoError(t, proxy.sessionManager.AddSession(sess)) + + doRequest := func() *http.Response { + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + "http://"+addr+"/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", clientSessionID) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + return resp + } + + // First request: triggers re-init. The replay (inside reinitializeAndReplay) sets + // Mcp-Session-Id directly, so receivedSIDs[0] is always the raw value regardless + // of what is stored in session metadata. + resp1 := doRequest() + _ = resp1.Body.Close() + require.Equal(t, http.StatusOK, resp1.StatusCode) + + // Second request: goes through the Rewrite closure, which reads backend_sid from + // session metadata. This is where the normalization bug manifests — if backend_sid + // was stored as normalizeSessionID(nonUUIDSessionID), Rewrite would forward the + // wrong hashed value and receivedSIDs[1] would not equal nonUUIDSessionID. 
+ resp2 := doRequest() + _ = resp2.Body.Close() + require.Equal(t, http.StatusOK, resp2.StatusCode) + + receivedMu.Lock() + defer receivedMu.Unlock() + require.Len(t, receivedSIDs, 2, "fresh backend should have received replay + second request") + assert.Equal(t, nonUUIDSessionID, receivedSIDs[0], "replay must forward raw non-UUID session ID") + assert.Equal(t, nonUUIDSessionID, receivedSIDs[1], "subsequent request via Rewrite must forward raw non-UUID session ID") +} + +// TestRoundTripReinitializesAfterPriorReinit verifies that re-initialization +// triggers correctly on a second failure when the session already has a +// backend_sid from a prior re-init. Without the clientSID capture fix, +// RoundTrip rewrites the header to backend_sid before calling reinitializeAndReplay, +// which then looks up the session by the (wrong) backend SID and finds nothing. +func TestRoundTripReinitializesAfterPriorReinit(t *testing.T) { + t.Parallel() + + firstBackendSID := uuid.New().String() + secondBackendSID := uuid.New().String() + + // staleBackend: returns 404 to trigger re-init. + staleBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer staleBackend.Close() + + var freshHit atomic.Int32 + freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + freshHit.Add(1) + body, _ := io.ReadAll(r.Body) + if strings.Contains(string(body), `"initialize"`) { + w.Header().Set("Mcp-Session-Id", secondBackendSID) + } + w.WriteHeader(http.StatusOK) + })) + defer freshBackend.Close() + + proxy, addr := startProxy(t, freshBackend.URL) + + // Session pre-populated as if a prior re-init already happened: + // backend_url points to staleBackend, backend_sid is set to firstBackendSID. 
+ clientSessionID := uuid.New().String() + sess := session.NewProxySession(clientSessionID) + sess.SetMetadata(sessionMetadataBackendURL, staleBackend.URL) + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + sess.SetMetadata(sessionMetadataBackendSID, firstBackendSID) + require.NoError(t, proxy.sessionManager.AddSession(sess)) + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + "http://"+addr+"/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", clientSessionID) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, + "client should see 200: re-init must use client SID for session lookup, not backend SID") + assert.GreaterOrEqual(t, freshHit.Load(), int32(2), + "fresh backend should receive re-initialize + replay") +} + +// TestRoundTripReinitializesOnDialError verifies that when the proxy cannot reach +// the stored pod IP (dial error — pod rescheduled to a new IP), it transparently +// re-initializes the backend session via the ClusterIP and replays the original +// request — the client sees a 200. +func TestRoundTripReinitializesOnDialError(t *testing.T) { + t.Parallel() + + // Create a server and immediately close it so its URL refuses connections. 
+ dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})) + deadURL := dead.URL + dead.Close() + + freshSessionID := uuid.New().String() + var freshHit atomic.Int32 + freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + freshHit.Add(1) + body, _ := io.ReadAll(r.Body) + if strings.Contains(string(body), `"initialize"`) { + w.Header().Set("Mcp-Session-Id", freshSessionID) + } + w.WriteHeader(http.StatusOK) + })) + defer freshBackend.Close() + + proxy, addr := startProxy(t, freshBackend.URL) + + clientSessionID := uuid.New().String() + sess := session.NewProxySession(clientSessionID) + sess.SetMetadata(sessionMetadataBackendURL, deadURL) + sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`) + require.NoError(t, proxy.sessionManager.AddSession(sess)) + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + "http://"+addr+"/mcp", + strings.NewReader(`{"method":"tools/list"}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Mcp-Session-Id", clientSessionID) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + _ = resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init on dial error") + assert.GreaterOrEqual(t, freshHit.Load(), int32(2), "fresh backend should receive initialize + replay") +} diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index 408abd7f08..f60f32798c 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -15,6 +15,7 @@ import ( "log/slog" "net" "net/http" + "net/http/httptrace" "net/http/httputil" "net/url" "os" @@ -153,6 +154,17 @@ const ( // It is written on initialize and read in the Rewrite closure to route 
follow-up requests // to the same backend pod that handled the session's initialize request. sessionMetadataBackendURL = "backend_url" + + // sessionMetadataInitBody stores the raw JSON-RPC initialize request body. + // It is used to transparently re-initialize a backend session when the pod that + // originally handled initialize has been replaced (new IP or lost in-memory state). + sessionMetadataInitBody = "init_body" + + // sessionMetadataBackendSID stores the backend's assigned Mcp-Session-Id when it + // diverges from the client-facing session ID after a transparent re-initialization. + // tracingTransport.RoundTrip rewrites the outbound Mcp-Session-Id header to this + // value so the backend sees its own session ID while the client keeps its original one. + sessionMetadataBackendSID = "backend_sid" ) // Option is a functional option for configuring TransparentProxy @@ -360,9 +372,36 @@ func NewTransparentProxyWithOptions( return proxy } +// recoverySessionStore is the subset of session.Manager that backendRecovery needs. +type recoverySessionStore interface { + Get(id string) (session.Session, bool) + UpsertSession(sess session.Session) error +} + +// backendRecovery handles transparent re-initialization of backend sessions when the +// target pod is unreachable (dial error) or has lost its in-memory session state (404). +// It depends only on a narrow session interface and a forward function, so it can be +// tested without standing up a full proxy. 
+type backendRecovery struct {
+	targetURI string
+	forward   func(*http.Request) (*http.Response, error)
+	sessions  recoverySessionStore
+}
+
 type tracingTransport struct {
-	base http.RoundTripper
-	p    *TransparentProxy
+	p        *TransparentProxy
+	recovery *backendRecovery
+}
+
+// newTracingTransport builds a tracingTransport; a nil base falls back to http.DefaultTransport.
+func newTracingTransport(base http.RoundTripper, p *TransparentProxy) *tracingTransport {
+	if base == nil {
+		base = http.DefaultTransport
+	}
+	return &tracingTransport{
+		p:        p,
+		recovery: &backendRecovery{targetURI: p.targetURI, forward: base.RoundTrip, sessions: p.sessionManager},
+	}
+}
 
 func (p *TransparentProxy) setServerInitialized() {
@@ -377,14 +416,6 @@ func (p *TransparentProxy) serverInitialized() bool {
 	return p.isServerInitialized.Load()
 }
 
-func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
-	tr := t.base
-	if tr == nil {
-		tr = http.DefaultTransport
-	}
-	return tr.RoundTrip(req)
-}
-
 // nolint:gocyclo // This function handles multiple request types and is complex by design
 func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
 	// Always rewrite Host header to match the target URL to avoid "Invalid Host" errors
@@ -436,12 +467,51 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
 		}
 	}
 
-	resp, err := t.forward(req)
+	// Capture the client-facing session ID before the backend SID rewrite below.
+	// Recovery and session cleanup paths must look up sessions by the client SID
+	// (the store key), not the backend SID that is written into the header.
+	clientSID := req.Header.Get("Mcp-Session-Id")
+
+	// Rewrite the outbound Mcp-Session-Id to the backend's assigned session ID when
+	// the proxy transparently re-initialized the backend session. This is done here
+	// (after the guard check above) so the guard always sees the original client
+	// session ID and can look it up correctly in the session store.
+	if clientSID != "" {
+		if sess, ok := t.p.sessionManager.Get(normalizeSessionID(clientSID)); ok {
+			if backendSID, exists := sess.GetMetadataValue(sessionMetadataBackendSID); exists && backendSID != "" {
+				req.Header.Set("Mcp-Session-Id", backendSID)
+			}
+		}
+	}
+
+	// Attach an httptrace to capture the actual backend pod IP after kube-proxy
+	// DNAT resolves the ClusterIP to a specific pod. The captured address is stored
+	// as backend_url so follow-up requests always reach the same pod, even after a
+	// proxy runner restart that would otherwise lose the in-memory routing state.
+	var capturedPodAddr string
+	if sawInitialize {
+		trace := &httptrace.ClientTrace{
+			GotConn: func(info httptrace.GotConnInfo) {
+				capturedPodAddr = info.Conn.RemoteAddr().String()
+			},
+		}
+		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+	}
+
+	resp, err := t.recovery.forward(req)
 	if err != nil {
 		if errors.Is(err, context.Canceled) {
 			// Expected during shutdown or client disconnect—silently ignore
 			return nil, err
 		}
+		// Dial error against a stored pod IP means the pod has been replaced. Only a
+		// request that carries a session can be recovered; others keep their headers.
+		if clientSID != "" && isDialError(err) {
+			req.Header.Set("Mcp-Session-Id", clientSID)
+			if reInitResp, reInitErr := t.recovery.reinitializeAndReplay(req, reqBody); reInitResp != nil || reInitErr != nil {
+				return reInitResp, reInitErr
+			}
+		}
 		slog.Error("failed to forward request", "error", err)
 		return nil, err
 	}
@@ -463,10 +533,26 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
 	// reference would only waste memory.
if req.Method == http.MethodDelete && (resp.StatusCode >= 200 && resp.StatusCode < 300 || resp.StatusCode == http.StatusNotFound) { - if sid := req.Header.Get("Mcp-Session-Id"); sid != "" { - if err := t.p.sessionManager.Delete(normalizeSessionID(sid)); err != nil { + if clientSID != "" { + if err := t.p.sessionManager.Delete(normalizeSessionID(clientSID)); err != nil { slog.Debug("failed to delete session from transparent proxy", - "session_id", sid, "error", err) + "session_id", clientSID, "error", err) + } + } + } + + // Backend returned 404 for a non-initialize, non-DELETE request whose session IS + // known to the proxy. This means the backend pod lost its in-memory session state + // (e.g. it was restarted but got the same IP). Attempt transparent re-initialization + // so the client sees no error. DELETE is excluded because the session has already + // been cleaned up above and the 404 is the expected terminal response. + if resp.StatusCode == http.StatusNotFound && !sawInitialize && req.Method != http.MethodDelete { + if clientSID != "" { + req.Header.Set("Mcp-Session-Id", clientSID) + if reInitResp, reInitErr := t.recovery.reinitializeAndReplay(req, reqBody); reInitResp != nil || reInitErr != nil { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + return reInitResp, reInitErr } } } @@ -480,14 +566,15 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) internalID := normalizeSessionID(ct) if _, ok := t.p.sessionManager.Get(internalID); !ok { sess := session.NewProxySession(internalID) - // Store targetURI as the default backend_url for this session. - // In single-replica deployments targetURI is already the pod address, - // so no override is needed. In multi-replica deployments the - // vMCP/operator layer is responsible for setting backend_url to the - // actual pod DNS name (e.g. 
http://mcp-server-0.mcp-server.default.svc:8080) - // before the request reaches this proxy; the Rewrite closure then reads - // that value and routes follow-up requests to the correct pod. - sess.SetMetadata(sessionMetadataBackendURL, t.p.targetURI) + // Store the actual pod IP (captured via GotConn) as backend_url so that + // after a proxy runner restart the session is routed to the same backend + // pod that handled initialize, not a random pod via ClusterIP. + sess.SetMetadata(sessionMetadataBackendURL, t.recovery.podBackendURL(capturedPodAddr)) + // Store the initialize body so we can transparently re-initialize the + // backend session if the pod is later replaced or loses session state. + if len(reqBody) > 0 { + sess.SetMetadata(sessionMetadataInitBody, string(reqBody)) + } if err := t.p.sessionManager.AddSession(sess); err != nil { //nolint:gosec // G706: session ID from HTTP response header slog.Error("failed to create session from header", @@ -553,6 +640,157 @@ func (t *tracingTransport) detectInitialize(body []byte) bool { return false } +// podBackendURL constructs a backend URL that targets the specific pod IP captured +// via httptrace.GotConn, using the scheme from targetURI. Falls back to targetURI +// when no address was captured (e.g. single-replica, connection reuse without a new conn), +// or when targetURI uses HTTPS — IP-literal HTTPS URLs fail TLS verification because +// server certificates are issued for hostnames, not pod IPs. +func (r *backendRecovery) podBackendURL(capturedAddr string) string { + if capturedAddr == "" { + return r.targetURI + } + parsed, err := url.Parse(r.targetURI) + if err != nil { + return r.targetURI + } + if parsed.Scheme == "https" { + return r.targetURI + } + parsed.Host = capturedAddr + return parsed.String() +} + +// isDialError reports whether err is a TCP dial failure, indicating that the +// target host is unreachable (pod has been terminated or rescheduled). 
+func isDialError(err error) bool { + var opErr *net.OpError + return errors.As(err, &opErr) && opErr.Op == "dial" +} + +// reinitializeAndReplay is called when the proxy detects that the backend pod +// that owned a session is no longer reachable (dial error) or has lost its +// in-memory session state (backend returned 404). It transparently: +// 1. Re-sends the stored initialize body to the ClusterIP service so kube-proxy +// selects a healthy pod and the backend creates a new session. +// 2. Captures the new pod IP via httptrace.GotConn and stores it as backend_url. +// 3. Maps the client's original session ID to the new backend session ID. +// 4. Replays the original client request so the client sees no error. +// +// Returns (nil, nil) when re-initialization is not applicable (session unknown +// to the proxy, or no stored init body for the session). +func (r *backendRecovery) reinitializeAndReplay(req *http.Request, origBody []byte) (*http.Response, error) { + sid := req.Header.Get("Mcp-Session-Id") + if sid == "" { + return nil, nil + } + internalSID := normalizeSessionID(sid) + sess, ok := r.sessions.Get(internalSID) + if !ok { + return nil, nil + } + + initBody, hasInit := sess.GetMetadataValue(sessionMetadataInitBody) + if !hasInit || initBody == "" { + // No stored init body — cannot re-initialize transparently. + // Reset backend_url to ClusterIP so the next request goes through + // kube-proxy and lets the client receive a clean 404 to re-initialize. + sess.SetMetadata(sessionMetadataBackendURL, r.targetURI) + _ = r.sessions.UpsertSession(sess) + return nil, nil + } + + slog.Debug("backend session lost; transparently re-initializing", + "session_id", sid, "target", r.targetURI) + + // Capture the new pod IP via GotConn on the re-initialize connection. 
+ var capturedPodAddr string + trace := &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + capturedPodAddr = info.Conn.RemoteAddr().String() + }, + } + initCtx := httptrace.WithClientTrace(req.Context(), trace) + + // Build a fresh initialize request to the ClusterIP (no Mcp-Session-Id — + // the backend assigns a new session ID in the response). + parsedTarget, err := url.Parse(r.targetURI) + if err != nil { + return nil, nil + } + initURL := *req.URL + initURL.Scheme = parsedTarget.Scheme + initURL.Host = parsedTarget.Host + + initReq, err := http.NewRequestWithContext(initCtx, http.MethodPost, initURL.String(), bytes.NewReader([]byte(initBody))) + if err != nil { + return nil, nil + } + // Propagate headers from the original request (Authorization, tenant headers, etc.) + // so the backend accepts the re-initialize. Mcp-Session-Id must not be forwarded — + // the backend assigns a new session ID in the response. Content-Length and + // Transfer-Encoding are deleted because http.NewRequestWithContext already set + // ContentLength from the body; leaving stale header values would be misleading + // (Go's transport ignores them in favour of the struct field, but clarity matters). 
+ initReq.Header = req.Header.Clone() + initReq.Header.Del("Mcp-Session-Id") + initReq.Header.Del("Content-Length") + initReq.Header.Del("Transfer-Encoding") + initReq.Header.Set("Content-Type", "application/json") + + initResp, err := r.forward(initReq) + if err != nil { + slog.Error("transparent re-initialize failed", "error", err) + return nil, err + } + _, _ = io.Copy(io.Discard, initResp.Body) + _ = initResp.Body.Close() + + newBackendSID := initResp.Header.Get("Mcp-Session-Id") + if newBackendSID == "" { + slog.Debug("re-initialize response contained no Mcp-Session-Id; falling back to ClusterIP") + sess.SetMetadata(sessionMetadataBackendURL, r.targetURI) + _ = r.sessions.UpsertSession(sess) + return nil, nil + } + + // Update session: point backend_url at the newly-discovered pod and record + // the backend session ID so tracingTransport.RoundTrip rewrites Mcp-Session-Id on outbound requests. + newPodURL := r.podBackendURL(capturedPodAddr) + sess.SetMetadata(sessionMetadataBackendURL, newPodURL) + // Store the raw backend session ID (not normalized) because the Rewrite closure + // uses this value verbatim as the outbound Mcp-Session-Id header. Normalizing + // would change non-UUID IDs to a UUID v5 hash the backend never issued. + sess.SetMetadata(sessionMetadataBackendSID, newBackendSID) + if upsertErr := r.sessions.UpsertSession(sess); upsertErr != nil { + slog.Debug("failed to update session after re-initialize", "error", upsertErr) + } + + // Replay the original client request to the new pod with the new backend SID. + // Use the captured pod address directly so we bypass the Rewrite closure + // (which still holds the old backend_url until the next session load). + // For HTTPS targets, keep the original hostname: IP-literal HTTPS requests + // fail TLS verification because server certs are issued for hostnames, not pod IPs. 
+ replayHost := capturedPodAddr + if replayHost == "" || parsedTarget.Scheme == "https" { + replayHost = parsedTarget.Host + } + replayReq := req.Clone(req.Context()) + replayReq.URL.Scheme = parsedTarget.Scheme + replayReq.URL.Host = replayHost + replayReq.Host = replayHost // keep Host header consistent with URL to avoid backend validation errors + replayReq.Header.Set("Mcp-Session-Id", newBackendSID) + replayReq.Body = io.NopCloser(bytes.NewReader(origBody)) + replayReq.ContentLength = int64(len(origBody)) + // origBody is fully buffered, so chunked encoding is unnecessary and would + // suppress the Content-Length header. Clear any TransferEncoding copied from + // the original request so net/http sends Content-Length instead. + replayReq.TransferEncoding = nil + + slog.Debug("replaying original request after transparent re-initialization", + "new_pod_url", newPodURL, "new_backend_sid", newBackendSID) + return r.forward(replayReq) +} + // modifyResponse modifies HTTP responses based on transport-specific requirements. // Delegates to the appropriate ResponseProcessor based on transport type. 
func (p *TransparentProxy) modifyResponse(resp *http.Response) error { @@ -643,7 +881,7 @@ func (p *TransparentProxy) Start(ctx context.Context) error { }, } - proxy.Transport = &tracingTransport{base: http.DefaultTransport, p: p} + proxy.Transport = newTracingTransport(http.DefaultTransport, p) proxy.ModifyResponse = func(resp *http.Response) error { return p.modifyResponse(resp) } diff --git a/pkg/transport/proxy/transparent/transparent_test.go b/pkg/transport/proxy/transparent/transparent_test.go index 4461de8524..cf5923bb6c 100644 --- a/pkg/transport/proxy/transparent/transparent_test.go +++ b/pkg/transport/proxy/transparent/transparent_test.go @@ -48,7 +48,7 @@ func TestStreamingSessionIDDetection(t *testing.T) { parsedURL, _ := http.NewRequest("GET", target.URL, nil) proxyURL := httputil.NewSingleHostReverseProxy(parsedURL.URL) proxyURL.FlushInterval = -1 - proxyURL.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + proxyURL.Transport = newTracingTransport(http.DefaultTransport, proxy) proxyURL.ModifyResponse = proxy.modifyResponse // hit the proxy @@ -77,7 +77,7 @@ func createBasicProxy(p *TransparentProxy, targetURL *url.URL) *httputil.Reverse pr.SetXForwarded() }, FlushInterval: -1, - Transport: &tracingTransport{base: http.DefaultTransport, p: p}, + Transport: newTracingTransport(http.DefaultTransport, p), ModifyResponse: p.modifyResponse, } return proxy @@ -172,7 +172,7 @@ func TestTracePropagationHeaders(t *testing.T) { }, } - reverseProxy.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + reverseProxy.Transport = newTracingTransport(http.DefaultTransport, proxy) // Create request with trace context ctx, span := otel.Tracer("test").Start(context.Background(), "test-operation") @@ -428,7 +428,7 @@ func TestTransparentProxy_UnauthorizedResponseCallback(t *testing.T) { // Create reverse proxy with tracing transport reverseProxy := httputil.NewSingleHostReverseProxy(targetURL) reverseProxy.FlushInterval = -1 - 
tracingTrans := &tracingTransport{base: http.DefaultTransport, p: proxy} + tracingTrans := newTracingTransport(http.DefaultTransport, proxy) reverseProxy.Transport = tracingTrans // Make a request through the proxy @@ -474,7 +474,7 @@ func TestTransparentProxy_UnauthorizedResponseCallback_Multiple401s(t *testing.T // Create reverse proxy with tracing transport reverseProxy := httputil.NewSingleHostReverseProxy(targetURL) reverseProxy.FlushInterval = -1 - reverseProxy.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + reverseProxy.Transport = newTracingTransport(http.DefaultTransport, proxy) // Make multiple requests through the proxy for i := 0; i < 5; i++ { @@ -519,7 +519,7 @@ func TestTransparentProxy_NoUnauthorizedCallbackOnSuccess(t *testing.T) { // Create reverse proxy with tracing transport reverseProxy := httputil.NewSingleHostReverseProxy(targetURL) reverseProxy.FlushInterval = -1 - reverseProxy.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + reverseProxy.Transport = newTracingTransport(http.DefaultTransport, proxy) // Make a request through the proxy rec := httptest.NewRecorder() @@ -556,7 +556,7 @@ func TestTransparentProxy_NilUnauthorizedCallback(t *testing.T) { // Create reverse proxy with tracing transport reverseProxy := httputil.NewSingleHostReverseProxy(targetURL) reverseProxy.FlushInterval = -1 - reverseProxy.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + reverseProxy.Transport = newTracingTransport(http.DefaultTransport, proxy) // Make a request through the proxy - should not panic rec := httptest.NewRecorder() @@ -848,7 +848,7 @@ func TestSSEEndpointRewriting(t *testing.T) { parsedURL, _ := http.NewRequest("GET", target.URL, nil) proxyURL := httputil.NewSingleHostReverseProxy(parsedURL.URL) proxyURL.FlushInterval = -1 - proxyURL.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + proxyURL.Transport = newTracingTransport(http.DefaultTransport, proxy) 
proxyURL.ModifyResponse = proxy.modifyResponse // Create request with X-Forwarded-Prefix header @@ -897,7 +897,7 @@ func TestSSEEndpointRewritingWithExplicitPrefix(t *testing.T) { parsedURL, _ := http.NewRequest("GET", target.URL, nil) proxyURL := httputil.NewSingleHostReverseProxy(parsedURL.URL) proxyURL.FlushInterval = -1 - proxyURL.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + proxyURL.Transport = newTracingTransport(http.DefaultTransport, proxy) proxyURL.ModifyResponse = proxy.modifyResponse rec := httptest.NewRecorder() @@ -945,7 +945,7 @@ func TestSSEMessageEventNotRewritten(t *testing.T) { parsedURL, _ := http.NewRequest("GET", target.URL, nil) proxyURL := httputil.NewSingleHostReverseProxy(parsedURL.URL) proxyURL.FlushInterval = -1 - proxyURL.Transport = &tracingTransport{base: http.DefaultTransport, p: proxy} + proxyURL.Transport = newTracingTransport(http.DefaultTransport, proxy) proxyURL.ModifyResponse = proxy.modifyResponse rec := httptest.NewRecorder() diff --git a/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go b/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go index 5eba000a64..c511155a8c 100644 --- a/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go +++ b/test/e2e/thv-operator/virtualmcp/mcpserver_scaling_test.go @@ -26,6 +26,7 @@ import ( mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" "github.com/stacklok/toolhive/test/e2e/images" + "github.com/stacklok/toolhive/test/e2e/thv-operator/testutil" ) // deployRedis creates a single-replica Redis Deployment and ClusterIP Service. @@ -103,6 +104,8 @@ func cleanupRedis(namespace, name string) { } // getReadyMCPServerPods returns all Running+Ready pods for an MCPServer. 
+// +//nolint:unparam // namespace kept as parameter for reusability across test contexts func getReadyMCPServerPods(mcpServerName, namespace string) ([]corev1.Pod, error) { podList := &corev1.PodList{} if err := k8sClient.List(ctx, podList, @@ -172,7 +175,7 @@ func portForwardToPod(podName, namespace string, targetPort int32) (int, func(), } // Wait for the port-forward to be ready - for i := 0; i < 30; i++ { + for range 30 { conn, dialErr := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", localPort), 500*time.Millisecond) if dialErr == nil { _ = conn.Close() @@ -193,6 +196,172 @@ var _ = ginkgo.Describe("MCPServer Cross-Replica Session Routing with Redis", fu proxyPort = int32(8080) ) + ginkgo.Context("When MCPServer has backendReplicas=2 and proxy runner restarts", ginkgo.Ordered, func() { + var ( + mcpServerName string + redisName string + nodePortName string + nodePort int32 + ) + + ginkgo.BeforeAll(func() { + ts := time.Now().UnixNano() + mcpServerName = fmt.Sprintf("e2e-backend-scale-%d", ts) + redisName = fmt.Sprintf("e2e-redis-be-%d", ts) + nodePortName = mcpServerName + "-nodeport" + + ginkgo.By("Deploying Redis for session storage") + deployRedis(defaultNamespace, redisName, timeout, pollInterval) + + replicas := int32(1) + backendReplicas := int32(2) + redisAddr := fmt.Sprintf("%s.%s.svc.cluster.local:6379", redisName, defaultNamespace) + + ginkgo.By("Creating MCPServer with replicas=1, backendReplicas=2, Redis session storage") + gomega.Expect(k8sClient.Create(ctx, &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: mcpServerName, Namespace: defaultNamespace}, + Spec: mcpv1alpha1.MCPServerSpec{ + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: proxyPort, + McpPort: 8080, + Replicas: &replicas, + BackendReplicas: &backendReplicas, + SessionAffinity: "None", + SessionStorage: &mcpv1alpha1.SessionStorageConfig{ + Provider: mcpv1alpha1.SessionStorageProviderRedis, + Address: redisAddr, + }, + }, + 
})).To(gomega.Succeed()) + + ginkgo.By("Waiting for MCPServer to be Running") + waitForMCPServerRunning(mcpServerName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Waiting for 1 ready proxy runner pod") + gomega.Eventually(func() (int, error) { + pods, err := getReadyMCPServerPods(mcpServerName, defaultNamespace) + if err != nil { + return 0, err + } + return len(pods), nil + }, timeout, pollInterval).Should(gomega.Equal(1)) + + ginkgo.By("Creating a NodePort service for external access to the proxy runner") + testutil.CreateNodePortService(ctx, k8sClient, mcpServerName, defaultNamespace) + + ginkgo.By("Waiting for NodePort to be assigned") + nodePort = testutil.GetNodePort(ctx, k8sClient, nodePortName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Waiting for NodePort to be accessible and serving HTTP health") + gomega.Eventually(func() error { + if err := checkPortAccessible(nodePort, 1*time.Second); err != nil { + return fmt.Errorf("nodePort %d not accessible: %w", nodePort, err) + } + if err := checkHTTPHealthReady(nodePort, 2*time.Second); err != nil { + return fmt.Errorf("nodePort %d not ready: %w", nodePort, err) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed(), "NodePort should be accessible and ready") + }) + + ginkgo.AfterAll(func() { + _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: mcpServerName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: nodePortName, Namespace: defaultNamespace}, + }) + cleanupRedis(defaultNamespace, redisName) + + gomega.Eventually(func() bool { + err := k8sClient.Get(ctx, types.NamespacedName{Name: mcpServerName, Namespace: defaultNamespace}, &mcpv1alpha1.MCPServer{}) + return apierrors.IsNotFound(err) + }, timeout, pollInterval).Should(gomega.BeTrue()) + }) + + ginkgo.It("Should route session to the correct backend after proxy runner restart", func() { + ginkgo.By("Initializing 
an MCP session via NodePort") + mcpClient, err := CreateInitializedMCPClient(nodePort, "e2e-backend-routing-test", 30*time.Second) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + + sessionID := mcpClient.Client.GetSessionId() + gomega.Expect(sessionID).NotTo(gomega.BeEmpty(), "session ID must be assigned after Initialize") + + ginkgo.By("Calling tools/list to verify session works before restart") + toolsBefore, err := mcpClient.Client.ListTools(mcpClient.Ctx, mcp.ListToolsRequest{}) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + gomega.Expect(toolsBefore.Tools).NotTo(gomega.BeEmpty()) + + // Cancel the context to stop in-flight requests but do NOT call Close(), + // which would send DELETE and remove the session from Redis. + // This simulates the real proxy-restart scenario: the proxy pod is killed + // mid-session, not the client explicitly terminating. + mcpClient.Cancel() + + ginkgo.By("Getting the current proxy runner pod name") + var pods []corev1.Pod + gomega.Eventually(func() (int, error) { + var listErr error + pods, listErr = getReadyMCPServerPods(mcpServerName, defaultNamespace) + if listErr != nil { + return 0, listErr + } + return len(pods), nil + }, timeout, pollInterval).Should(gomega.Equal(1)) + oldPodName := pods[0].Name + + ginkgo.By(fmt.Sprintf("Deleting proxy runner pod %s (Deployment will recreate it)", oldPodName)) + gomega.Expect(k8sClient.Delete(ctx, &pods[0])).To(gomega.Succeed()) + + ginkgo.By("Waiting for new proxy runner pod to be Running+Ready") + gomega.Eventually(func() (string, error) { + newPods, listErr := getReadyMCPServerPods(mcpServerName, defaultNamespace) + if listErr != nil || len(newPods) == 0 { + return "", fmt.Errorf("waiting for new pod") + } + if newPods[0].Name == oldPodName { + return "", fmt.Errorf("old pod %s still present", oldPodName) + } + return newPods[0].Name, nil + }, timeout, pollInterval).ShouldNot(gomega.BeEmpty()) + + ginkgo.By("Waiting for NodePort to be accessible on the new pod") + 
gomega.Eventually(func() error { + if err := checkHTTPHealthReady(nodePort, 2*time.Second); err != nil { + return fmt.Errorf("nodePort %d not ready after restart: %w", nodePort, err) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed()) + + ginkgo.By("Creating a new client with the SAME session ID") + serverURL := fmt.Sprintf("http://localhost:%d/mcp", nodePort) + newClient, err := mcpclient.NewStreamableHttpClient(serverURL, transport.WithSession(sessionID)) + gomega.Expect(err).NotTo(gomega.HaveOccurred()) + defer func() { _ = newClient.Close() }() + + startCtx, startCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer startCancel() + gomega.Expect(newClient.Start(startCtx)).To(gomega.Succeed()) + + // The proxy now stores the actual backend pod IP (captured via httptrace) + // and transparently re-initializes when that pod is unreachable or has lost + // session state. Send 5 requests to give confidence the fix holds: without + // pod-IP pinning and transparent re-init, random ClusterIP routing with 2 + // backends would cause ~97% of these sequences to hit at least one wrong pod. + ginkgo.By("Sending 5 requests with the recovered session to verify backend routing") + for i := range 5 { + listCtx, listCancel := context.WithTimeout(context.Background(), 30*time.Second) + toolsAfter, listErr := newClient.ListTools(listCtx, mcp.ListToolsRequest{}) + listCancel() + gomega.Expect(listErr).NotTo(gomega.HaveOccurred(), + "Request %d/5 should succeed — session should route to the correct backend", i+1) + gomega.Expect(toolsAfter.Tools).To(gomega.HaveLen(len(toolsBefore.Tools)), + "Request %d/5 should return the same tools as before restart", i+1) + } + }) + }) + ginkgo.Context("When MCPServer has replicas=2 with Redis session storage", ginkgo.Ordered, func() { var ( mcpServerName string