Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 228 additions & 0 deletions pkg/transport/proxy/transparent/backend_routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -289,3 +290,230 @@ func TestRoundTripStoresBackendURLOnInitialize(t *testing.T) {
require.True(t, ok, "session should have backend_url metadata")
assert.Equal(t, backend.URL, backendURL)
}

// TestRoundTripStoresInitBodyOnInitialize sends an initialize request through
// the proxy and checks that the exact JSON-RPC body ends up in the session
// metadata under sessionMetadataInitBody. The stored body is what allows the
// proxy to re-initialize a backend session later if the pod is replaced.
func TestRoundTripStoresInitBodyOnInitialize(t *testing.T) {
	t.Parallel()

	const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
	backendSID := uuid.New().String()

	// The backend answers every request with 200 and a fixed Mcp-Session-Id.
	handler := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Mcp-Session-Id", backendSID)
		w.WriteHeader(http.StatusOK)
	}
	backend := httptest.NewServer(http.HandlerFunc(handler))
	defer backend.Close()

	proxy, addr := startProxy(t, backend.URL)

	endpoint := "http://" + addr + "/mcp"
	req, err := http.NewRequestWithContext(
		context.Background(), http.MethodPost, endpoint, strings.NewReader(initBody))
	require.NoError(t, err)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	_ = resp.Body.Close()

	// The session is looked up by the normalized form of the backend-issued ID.
	sess, found := proxy.sessionManager.Get(normalizeSessionID(backendSID))
	require.True(t, found, "session should have been created")

	got, present := sess.GetMetadataValue(sessionMetadataInitBody)
	require.True(t, present, "init_body should be stored in session metadata")
	assert.Equal(t, initBody, got)
}

// TestRoundTripReinitializesOnBackend404 covers the case where the backend
// recorded for a session answers 404 (it no longer knows the session). The
// proxy is expected to re-initialize against its target URI, replay the
// original request there, and hand the client a 200.
func TestRoundTripReinitializesOnBackend404(t *testing.T) {
	t.Parallel()

	const (
		initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`
		listBody = `{"method":"tools/list"}`
	)

	// stale: always 404, standing in for a pod that lost its session state.
	var staleCalls atomic.Int32
	stale := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		staleCalls.Add(1)
		w.WriteHeader(http.StatusNotFound)
	}))
	defer stale.Close()

	// fresh: always 200; additionally issues a session ID when the body
	// mentions "initialize".
	var freshCalls atomic.Int32
	issuedSID := uuid.New().String()
	fresh := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		freshCalls.Add(1)
		payload, _ := io.ReadAll(r.Body)
		isInit := strings.Contains(string(payload), `"initialize"`)
		if isInit {
			w.Header().Set("Mcp-Session-Id", issuedSID)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer fresh.Close()

	// The proxy targets the fresh backend, while the pre-seeded session
	// still points at the stale one.
	proxy, addr := startProxy(t, fresh.URL)

	clientSID := uuid.New().String()
	seeded := session.NewProxySession(clientSID)
	seeded.SetMetadata(sessionMetadataBackendURL, stale.URL)
	seeded.SetMetadata(sessionMetadataInitBody, initBody)
	require.NoError(t, proxy.sessionManager.AddSession(seeded))

	req, err := http.NewRequestWithContext(
		context.Background(), http.MethodPost, "http://"+addr+"/mcp", strings.NewReader(listBody))
	require.NoError(t, err)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Mcp-Session-Id", clientSID)

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	_ = resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init")
	assert.GreaterOrEqual(t, staleCalls.Load(), int32(1), "stale backend should have been hit")
	assert.GreaterOrEqual(t, freshCalls.Load(), int32(2), "fresh backend should receive initialize + replay")

	// After re-init the session must carry the backend's new session ID,
	// stored exactly as issued.
	after, found := proxy.sessionManager.Get(normalizeSessionID(clientSID))
	require.True(t, found, "session should still exist after re-init")
	gotSID, present := after.GetMetadataValue(sessionMetadataBackendSID)
	require.True(t, present, "backend_sid should be set after re-init")
	assert.Equal(t, issuedSID, gotSID, "backend_sid must be the raw value the backend issued, not normalized")
}

// TestRoundTripReinitializesPreservesNonUUIDBackendSessionID guards against
// the proxy normalizing a non-UUID Mcp-Session-Id that the backend hands out
// during re-initialization. The raw value must be stored and forwarded on
// every later request.
//
// Two client requests are needed to expose the bug. The replay that follows
// re-initialization sets Mcp-Session-Id straight from the new backend session
// ID, so it is correct either way. The next request goes through the Rewrite
// closure, which reads backend_sid from session metadata; had that been stored
// as normalizeSessionID(newBackendSID), the backend would receive a hashed ID.
func TestRoundTripReinitializesPreservesNonUUIDBackendSessionID(t *testing.T) {
	t.Parallel()

	// An opaque, non-UUID token of the kind some MCP servers issue.
	const nonUUIDSessionID = "opaque-session-token-abc123"
	const initBody = `{"jsonrpc":"2.0","id":1,"method":"initialize"}`

	stale := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	}))
	defer stale.Close()

	// seen collects, in arrival order, the Mcp-Session-Id header of every
	// non-initialize request: [0] is the replay, [1] the second client call.
	var mu sync.Mutex
	var seen []string

	fresh := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		payload, _ := io.ReadAll(r.Body)
		if !strings.Contains(string(payload), `"initialize"`) {
			mu.Lock()
			seen = append(seen, r.Header.Get("Mcp-Session-Id"))
			mu.Unlock()
		} else {
			w.Header().Set("Mcp-Session-Id", nonUUIDSessionID)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer fresh.Close()

	proxy, addr := startProxy(t, fresh.URL)

	clientSID := uuid.New().String()
	seeded := session.NewProxySession(clientSID)
	seeded.SetMetadata(sessionMetadataBackendURL, stale.URL)
	seeded.SetMetadata(sessionMetadataInitBody, initBody)
	require.NoError(t, proxy.sessionManager.AddSession(seeded))

	// listTools posts a tools/list call for the client session, closes the
	// response body, and reports the status code.
	listTools := func() int {
		req, err := http.NewRequestWithContext(
			context.Background(), http.MethodPost, "http://"+addr+"/mcp",
			strings.NewReader(`{"method":"tools/list"}`))
		require.NoError(t, err)
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Mcp-Session-Id", clientSID)

		resp, err := http.DefaultClient.Do(req)
		require.NoError(t, err)
		_ = resp.Body.Close()
		return resp.StatusCode
	}

	// Request 1 hits the stale backend, triggering re-init plus replay.
	require.Equal(t, http.StatusOK, listTools())

	// Request 2 is routed using the backend_sid kept in session metadata;
	// this is the one that would carry a hashed ID if the bug were present.
	require.Equal(t, http.StatusOK, listTools())

	mu.Lock()
	got := append([]string(nil), seen...)
	mu.Unlock()

	require.Len(t, got, 2, "fresh backend should have received replay + second request")
	assert.Equal(t, nonUUIDSessionID, got[0], "replay must forward raw non-UUID session ID")
	assert.Equal(t, nonUUIDSessionID, got[1], "subsequent request via Rewrite must forward raw non-UUID session ID")
}

// TestRoundTripReinitializesOnDialError verifies that when the proxy cannot reach
// the backend URL stored on the session (dial error — e.g. the pod was rescheduled
// to a new IP), it transparently re-initializes the backend session via the
// proxy's target URI and replays the original request — the client sees a 200.
func TestRoundTripReinitializesOnDialError(t *testing.T) {
	t.Parallel()

	// freshBackend simulates a healthy pod: returns a session ID on initialize
	// and 200 for all other requests.
	freshSessionID := uuid.New().String()
	var freshHit atomic.Int32
	freshBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		freshHit.Add(1)
		// A read error leaves body empty, which is simply treated as a
		// non-initialize request; the assertions below still catch a failure.
		body, _ := io.ReadAll(r.Body)
		if strings.Contains(string(body), `"initialize"`) {
			w.Header().Set("Mcp-Session-Id", freshSessionID)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer freshBackend.Close()

	proxy, addr := startProxy(t, freshBackend.URL)

	// Obtain an address that refuses connections by starting a server and
	// closing it immediately. This is done only AFTER freshBackend and the
	// proxy have bound their own listeners: all of them ask for port 0, so if
	// the dead server were released first, its port could be handed to one of
	// this test's own listeners and the "dead" URL would silently become live
	// (no dial error, no re-init, flaky freshHit assertion).
	dead := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
	deadURL := dead.URL
	dead.Close()

	// Seed a session whose recorded backend is the unreachable address.
	clientSessionID := uuid.New().String()
	sess := session.NewProxySession(clientSessionID)
	sess.SetMetadata(sessionMetadataBackendURL, deadURL)
	sess.SetMetadata(sessionMetadataInitBody, `{"jsonrpc":"2.0","id":1,"method":"initialize"}`)
	require.NoError(t, proxy.sessionManager.AddSession(sess))

	ctx := context.Background()
	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
		"http://"+addr+"/mcp",
		strings.NewReader(`{"method":"tools/list"}`))
	require.NoError(t, err)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Mcp-Session-Id", clientSessionID)

	resp, err := http.DefaultClient.Do(req)
	require.NoError(t, err)
	_ = resp.Body.Close()

	assert.Equal(t, http.StatusOK, resp.StatusCode, "client should see 200 after transparent re-init on dial error")
	assert.GreaterOrEqual(t, freshHit.Load(), int32(2), "fresh backend should receive initialize + replay")
}
Loading
Loading