Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pkg/vmcp/aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,16 @@ type Aggregator interface {
// toolsByBackend maps backend WorkloadID → raw tools as returned by the backend.
// targets maps backend WorkloadID → the pre-built BackendTarget for that backend.
//
// Returns the advertised tool list (resolved names, filtered) and a routing table
// keyed by resolved name. Each routing table entry has OriginalCapabilityName set
// so that GetBackendCapabilityName() translates back to the raw backend name.
// Returns:
// - advertisedTools: resolved tools that pass the advertising filter (for MCP clients)
// - allResolvedTools: all resolved tools including non-advertised ones (for schema lookup)
// - toolsRouting: routing table keyed by resolved name; each entry has OriginalCapabilityName
// set so that GetBackendCapabilityName() translates back to the raw backend name.
ProcessPreQueriedCapabilities(
ctx context.Context,
toolsByBackend map[string][]vmcp.Tool,
targets map[string]*vmcp.BackendTarget,
) (advertisedTools []vmcp.Tool, toolsRouting map[string]*vmcp.BackendTarget, err error)
) (advertisedTools []vmcp.Tool, allResolvedTools []vmcp.Tool, toolsRouting map[string]*vmcp.BackendTarget, err error)
}

// BackendCapabilities contains the raw capabilities from a single backend.
Expand Down
31 changes: 19 additions & 12 deletions pkg/vmcp/aggregator/default_aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ func (a *defaultAggregator) ProcessPreQueriedCapabilities(
ctx context.Context,
toolsByBackend map[string][]vmcp.Tool,
targets map[string]*vmcp.BackendTarget,
) ([]vmcp.Tool, map[string]*vmcp.BackendTarget, error) {
) ([]vmcp.Tool, []vmcp.Tool, map[string]*vmcp.BackendTarget, error) {
// Step 1: Apply per-backend overrides (renames, description changes).
processed := make(map[string]*BackendCapabilities, len(toolsByBackend))
for backendID, rawTools := range toolsByBackend {
Expand All @@ -515,11 +515,16 @@ func (a *defaultAggregator) ProcessPreQueriedCapabilities(
// Step 2: Resolve naming conflicts across backends.
resolved, err := a.ResolveConflicts(ctx, processed)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

// Step 3: Build advertised list and routing table, applying advertising filter.
// Step 3: Build advertised list, all-resolved list, and routing table.
// advertisedTools is the subset shown to MCP clients (post-filter).
// allResolvedTools includes every resolved tool regardless of advertising filter,
// so that workflow engines can look up InputSchema for type coercion even when
// a backend tool is hidden from clients via excludeAll or filter configuration.
var advertisedTools []vmcp.Tool
var allResolvedTools []vmcp.Tool
routingTable := make(map[string]*vmcp.BackendTarget, len(resolved.Tools))

for _, rt := range resolved.Tools {
Expand All @@ -536,19 +541,21 @@ func (a *defaultAggregator) ProcessPreQueriedCapabilities(
t.OriginalCapabilityName = actualBackendCapabilityName(a.toolConfigMap, rt.BackendID, rt.OriginalName)
routingTable[rt.ResolvedName] = &t

resolved := vmcp.Tool{
Name: rt.ResolvedName,
Description: rt.Description,
InputSchema: rt.InputSchema,
OutputSchema: rt.OutputSchema,
Annotations: rt.Annotations,
BackendID: rt.BackendID,
}
allResolvedTools = append(allResolvedTools, resolved)
if a.shouldAdvertiseTool(rt.BackendID, rt.OriginalName) {
advertisedTools = append(advertisedTools, vmcp.Tool{
Name: rt.ResolvedName,
Description: rt.Description,
InputSchema: rt.InputSchema,
OutputSchema: rt.OutputSchema,
Annotations: rt.Annotations,
BackendID: rt.BackendID,
})
advertisedTools = append(advertisedTools, resolved)
}
}

return advertisedTools, routingTable, nil
return advertisedTools, allResolvedTools, routingTable, nil
}

// actualBackendCapabilityName returns the real capability name the backend uses,
Expand Down
39 changes: 31 additions & 8 deletions pkg/vmcp/aggregator/default_aggregator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}

agg := NewDefaultAggregator(nil, nil, nil, nil)
advertised, routingTable, err := agg.ProcessPreQueriedCapabilities(
advertised, allResolved, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(), toolsByBackend, targets,
)

Expand All @@ -807,6 +807,9 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}
assert.Contains(t, advertisedNames, "tool1")
assert.Contains(t, advertisedNames, "tool2")
// With no filter, allResolved must equal the advertised list.
assert.ElementsMatch(t, advertised, allResolved,
"without a filter, allResolvedTools must equal the advertised list")
// Both tools must be in the routing table.
assert.Contains(t, routingTable, "tool1")
assert.Contains(t, routingTable, "tool2")
Expand All @@ -825,7 +828,7 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}

agg := NewDefaultAggregator(nil, nil, nil, nil)
_, routingTable, err := agg.ProcessPreQueriedCapabilities(
_, _, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(), toolsByBackend, targets,
)

Expand Down Expand Up @@ -858,7 +861,7 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}

agg := NewDefaultAggregator(nil, nil, aggCfg, nil)
advertised, routingTable, err := agg.ProcessPreQueriedCapabilities(
advertised, _, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(), toolsByBackend, targets,
)

Expand Down Expand Up @@ -892,7 +895,7 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}

agg := NewDefaultAggregator(nil, nil, nil, nil)
advertised, routingTable, err := agg.ProcessPreQueriedCapabilities(
advertised, _, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(), toolsByBackend, targets,
)

Expand Down Expand Up @@ -926,13 +929,23 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}

agg := NewDefaultAggregator(nil, nil, aggCfg, nil)
advertised, routingTable, err := agg.ProcessPreQueriedCapabilities(
advertised, allResolved, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(), toolsByBackend, targets,
)

require.NoError(t, err)
assert.Empty(t, advertised,
"ExcludeAllTools must produce an empty advertised list")
// allResolvedTools must contain all tools regardless of the advertising filter,
// so the workflow engine can look up InputSchema for type coercion.
allResolvedNames := make([]string, 0, len(allResolved))
for _, tool := range allResolved {
allResolvedNames = append(allResolvedNames, tool.Name)
}
assert.Contains(t, allResolvedNames, "tool1",
"excluded tools must appear in allResolvedTools for composite tool schema lookup")
assert.Contains(t, allResolvedNames, "tool2",
"excluded tools must appear in allResolvedTools for composite tool schema lookup")
// Tools must still be routable (composite tools need them).
assert.Contains(t, routingTable, "tool1",
"excluded tools must remain in the routing table for composite tool use")
Expand Down Expand Up @@ -963,7 +976,7 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}

agg := NewDefaultAggregator(nil, nil, aggCfg, nil)
advertised, routingTable, err := agg.ProcessPreQueriedCapabilities(
advertised, allResolved, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(), toolsByBackend, targets,
)

Expand All @@ -975,6 +988,16 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}
assert.Equal(t, []string{"allowed_tool"}, advertisedNames,
"only tools matching the filter should be advertised")
// allResolvedTools must include both tools so the workflow engine can
// look up InputSchema for type coercion on hidden_tool.
allResolvedNames := make([]string, 0, len(allResolved))
for _, tool := range allResolved {
allResolvedNames = append(allResolvedNames, tool.Name)
}
assert.Contains(t, allResolvedNames, "allowed_tool",
"filtered-in tool must appear in allResolvedTools")
assert.Contains(t, allResolvedNames, "hidden_tool",
"filtered-out tool must appear in allResolvedTools for composite tool schema lookup")
// Both tools remain routable (composite tools can call hidden_tool).
assert.Contains(t, routingTable, "allowed_tool",
"filtered-in tool should be in routing table")
Expand All @@ -995,7 +1018,7 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
}

agg := NewDefaultAggregator(nil, nil, nil, nil)
advertised, routingTable, err := agg.ProcessPreQueriedCapabilities(
advertised, _, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(), toolsByBackend, targets,
)

Expand All @@ -1018,7 +1041,7 @@ func TestDefaultAggregator_ProcessPreQueriedCapabilities(t *testing.T) {
t.Parallel()

agg := NewDefaultAggregator(nil, nil, nil, nil)
advertised, routingTable, err := agg.ProcessPreQueriedCapabilities(
advertised, _, routingTable, err := agg.ProcessPreQueriedCapabilities(
context.Background(),
map[string][]vmcp.Tool{},
map[string]*vmcp.BackendTarget{},
Expand Down
9 changes: 5 additions & 4 deletions pkg/vmcp/aggregator/mocks/mock_interfaces.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/vmcp/server/session_management_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func newNoopMockFactory(t *testing.T) *sessionfactorymocks.MockMultiSessionFacto
mock.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes()
mock.EXPECT().SetMetadata(gomock.Any(), gomock.Any()).AnyTimes()
mock.EXPECT().Tools().Return(nil).AnyTimes()
mock.EXPECT().AllTools().Return(nil).AnyTimes()
mock.EXPECT().Resources().Return(nil).AnyTimes()
mock.EXPECT().Prompts().Return(nil).AnyTimes()
mock.EXPECT().BackendSessions().Return(nil).AnyTimes()
Expand Down Expand Up @@ -110,6 +111,7 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, tools []vmcp.Tool) (*
toolsCopy := make([]vmcp.Tool, len(tools))
copy(toolsCopy, tools)
mock.EXPECT().Tools().Return(toolsCopy).AnyTimes()
mock.EXPECT().AllTools().Return(toolsCopy).AnyTimes()
mock.EXPECT().Resources().Return(nil).AnyTimes()
mock.EXPECT().Prompts().Return(nil).AnyTimes()
mock.EXPECT().BackendSessions().Return(nil).AnyTimes()
Expand Down
4 changes: 2 additions & 2 deletions pkg/vmcp/server/sessionmanager/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ func compositeToolsDecorator(
}

compositeToolsMeta := compositetools.ConvertWorkflowDefsToTools(sessionDefs)
if err := compositetools.ValidateNoToolConflicts(sess.Tools(), compositeToolsMeta); err != nil {
if err := compositetools.ValidateNoToolConflicts(sess.AllTools(), compositeToolsMeta); err != nil {
slog.Warn("composite tool name conflict detected; skipping composite tools", "session_id", sess.ID(), "error", err)
return sess, nil
}

sessionComposer := composerFactory(sess.GetRoutingTable(), sess.Tools())
sessionComposer := composerFactory(sess.GetRoutingTable(), sess.AllTools())
sessionExecutors := make(map[string]compositetools.WorkflowExecutor, len(sessionDefs))
for _, def := range sessionDefs {
ex := newComposerWorkflowExecutor(sessionComposer, def)
Expand Down
1 change: 1 addition & 0 deletions pkg/vmcp/server/telemetry_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ type backendAwareTestSession struct {
}

func (s *backendAwareTestSession) Tools() []vmcp.Tool { return s.tools }
func (s *backendAwareTestSession) AllTools() []vmcp.Tool { return s.tools }
func (*backendAwareTestSession) Resources() []vmcp.Resource { return nil }
func (*backendAwareTestSession) Prompts() []vmcp.Prompt { return nil }
func (*backendAwareTestSession) BackendSessions() map[string]string { return nil }
Expand Down
13 changes: 11 additions & 2 deletions pkg/vmcp/session/default_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,30 @@ type defaultMultiSession struct {
// All fields below are written once by MakeSession and are read-only thereafter.
connections map[string]backend.Session
routingTable *vmcp.RoutingTable
tools []vmcp.Tool
tools []vmcp.Tool // advertised tools (shown to MCP clients)
allTools []vmcp.Tool // all resolved tools, including non-advertised ones
resources []vmcp.Resource
prompts []vmcp.Prompt
backendSessions map[string]string

queue AdmissionQueue
}

// Tools returns a snapshot copy of the tools available in this session.
// Tools returns a snapshot copy of the advertised tools available in this session.
func (s *defaultMultiSession) Tools() []vmcp.Tool {
result := make([]vmcp.Tool, len(s.tools))
copy(result, s.tools)
return result
}

// AllTools returns a snapshot copy of all resolved tools in this session,
// including tools excluded from advertising to MCP clients.
func (s *defaultMultiSession) AllTools() []vmcp.Tool {
result := make([]vmcp.Tool, len(s.allTools))
copy(result, s.allTools)
return result
}

// Resources returns a snapshot copy of the resources available in this session.
func (s *defaultMultiSession) Resources() []vmcp.Resource {
result := make([]vmcp.Resource, len(s.resources))
Expand Down
29 changes: 18 additions & 11 deletions pkg/vmcp/session/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,14 @@ func buildRoutingTable(results []initResult) (*vmcp.RoutingTable, []vmcp.Tool, [
// pipeline (overrides, conflict resolution, advertising filter) to the raw
// backend capabilities in results, producing resolved tool names identical to
// the standard aggregation path. Resources and prompts pass through unchanged.
//
// Returns the routing table, advertised tools (for MCP clients), all resolved
// tools (for schema lookup), resources, prompts, and any error.
func buildRoutingTableWithAggregator(
ctx context.Context,
agg aggregator.Aggregator,
results []initResult,
) (*vmcp.RoutingTable, []vmcp.Tool, []vmcp.Resource, []vmcp.Prompt, error) {
) (*vmcp.RoutingTable, []vmcp.Tool, []vmcp.Tool, []vmcp.Resource, []vmcp.Prompt, error) {
toolsByBackend := make(map[string][]vmcp.Tool, len(results))
targets := make(map[string]*vmcp.BackendTarget, len(results))
for i := range results {
Expand All @@ -303,9 +306,9 @@ func buildRoutingTableWithAggregator(
targets[r.target.WorkloadID] = r.target
}

allTools, toolsRouting, err := agg.ProcessPreQueriedCapabilities(ctx, toolsByBackend, targets)
advertisedTools, allResolvedTools, toolsRouting, err := agg.ProcessPreQueriedCapabilities(ctx, toolsByBackend, targets)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, err
}

rt := &vmcp.RoutingTable{
Expand All @@ -331,7 +334,7 @@ func buildRoutingTableWithAggregator(
}
}

return rt, allTools, allResources, allPrompts, nil
return rt, advertisedTools, allResolvedTools, allResources, allPrompts, nil
}

// MakeSessionWithID implements MultiSessionFactory.
Expand Down Expand Up @@ -458,19 +461,22 @@ func (f *defaultMultiSessionFactory) makeBaseSession(
}

var (
routingTable *vmcp.RoutingTable
allTools []vmcp.Tool
allResources []vmcp.Resource
allPrompts []vmcp.Prompt
routingTable *vmcp.RoutingTable
advertisedTools []vmcp.Tool
allResolvedTools []vmcp.Tool
allResources []vmcp.Resource
allPrompts []vmcp.Prompt
)
if f.aggregator != nil {
var aggErr error
routingTable, allTools, allResources, allPrompts, aggErr = buildRoutingTableWithAggregator(ctx, f.aggregator, results)
routingTable, advertisedTools, allResolvedTools, allResources, allPrompts, aggErr =
buildRoutingTableWithAggregator(ctx, f.aggregator, results)
if aggErr != nil {
return nil, fmt.Errorf("failed to process backend capabilities: %w", aggErr)
}
} else {
routingTable, allTools, allResources, allPrompts = buildRoutingTable(results)
routingTable, advertisedTools, allResources, allPrompts = buildRoutingTable(results)
allResolvedTools = advertisedTools // no filter when no aggregator
}

transportSess := transportsession.NewStreamableSession(sessID)
Expand All @@ -483,7 +489,8 @@ func (f *defaultMultiSessionFactory) makeBaseSession(
Session: transportSess,
connections: connections,
routingTable: routingTable,
tools: allTools,
tools: advertisedTools,
allTools: allResolvedTools,
resources: allResources,
prompts: allPrompts,
backendSessions: backendSessions,
Expand Down
14 changes: 14 additions & 0 deletions pkg/vmcp/session/types/mocks/mock_session.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion pkg/vmcp/session/types/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,15 @@ type MultiSession interface {
transportsession.Session
Caller

// Tools returns the resolved tools available in this session.
// Tools returns the advertised tools available in this session (shown to MCP clients).
// The list is built once at session creation and is read-only thereafter.
Tools() []vmcp.Tool

// AllTools returns all resolved tools in this session, including tools that are
// excluded from advertising to MCP clients via excludeAll or filter configuration.
// Used by the workflow engine for argument type coercion via InputSchema lookup.
AllTools() []vmcp.Tool

// Resources returns the resolved resources available in this session.
Resources() []vmcp.Resource

Expand Down
Loading