Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 359 additions & 0 deletions toolrunner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,362 @@ func TestToolRunner_MalformedJSONInput(t *testing.T) {
t.Fatal("expected error for invalid JSON in Execute")
}
}

// TestToolRunner_SchemaValidation verifies that the tool runner validates inputs
// against the JSON Schema before executing the handler. This prevents missing
// required fields, enum violations, and type mismatches from reaching handlers.
func TestToolRunner_SchemaValidation(t *testing.T) {
t.Parallel()

type StrictInput struct {
City string `json:"city"`
Units string `json:"units,omitempty"`
}

handlerCalled := false
tool, err := toolrunner.NewBetaToolFromBytes("weather", "Get weather", schemaToBytes(t, weatherSchema),
func(ctx context.Context, input StrictInput) (anthropic.BetaToolResultBlockParamContentUnion, error) {
handlerCalled = true
return anthropic.BetaToolResultBlockParamContentUnion{
OfText: &anthropic.BetaTextBlockParam{Text: fmt.Sprintf("Weather in %s (%s)", input.City, input.Units)},
}, nil
})
if err != nil {
t.Fatalf("create tool: %v", err)
}

t.Run("valid input passes validation", func(t *testing.T) {
handlerCalled = false
input := json.RawMessage(`{"city": "London", "units": "celsius"}`)
result, err := tool.Execute(context.Background(), input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !handlerCalled {
t.Fatal("handler was not called for valid input")
}
if result.OfText == nil || result.OfText.Text != "Weather in London (celsius)" {
t.Fatalf("unexpected result: %+v", result)
}
})

t.Run("missing required field rejected", func(t *testing.T) {
handlerCalled = false
// "city" is required but missing
input := json.RawMessage(`{"units": "celsius"}`)
_, err := tool.Execute(context.Background(), input)
if err == nil {
t.Fatal("expected error for missing required field 'city', got nil")
}
if handlerCalled {
t.Fatal("handler should NOT be called when schema validation fails")
}
if !strings.Contains(err.Error(), "schema validation failed") {
t.Fatalf("error should mention schema validation, got: %v", err)
}
})

t.Run("enum violation rejected", func(t *testing.T) {
handlerCalled = false
// "units" must be "celsius" or "fahrenheit"
input := json.RawMessage(`{"city": "London", "units": "kelvin"}`)
_, err := tool.Execute(context.Background(), input)
if err == nil {
t.Fatal("expected error for enum violation on 'units', got nil")
}
if handlerCalled {
t.Fatal("handler should NOT be called when schema validation fails")
}
if !strings.Contains(err.Error(), "schema validation failed") {
t.Fatalf("error should mention schema validation, got: %v", err)
}
})

t.Run("wrong type rejected", func(t *testing.T) {
handlerCalled = false
// "city" should be string, not number
input := json.RawMessage(`{"city": 12345}`)
_, err := tool.Execute(context.Background(), input)
if err == nil {
t.Fatal("expected error for wrong type on 'city', got nil")
}
if handlerCalled {
t.Fatal("handler should NOT be called when schema validation fails")
}
})

t.Run("empty object rejected when required fields exist", func(t *testing.T) {
handlerCalled = false
input := json.RawMessage(`{}`)
_, err := tool.Execute(context.Background(), input)
if err == nil {
t.Fatal("expected error for empty object with required fields, got nil")
}
if handlerCalled {
t.Fatal("handler should NOT be called when schema validation fails")
}
})

t.Run("optional field can be omitted", func(t *testing.T) {
handlerCalled = false
// "units" is optional, only "city" is required
input := json.RawMessage(`{"city": "Tokyo"}`)
_, err := tool.Execute(context.Background(), input)
if err != nil {
t.Fatalf("unexpected error for valid input without optional field: %v", err)
}
if !handlerCalled {
t.Fatal("handler was not called for valid input")
}
})
}

// TestToolRunner_AdditionalPropertiesRejected verifies that additionalProperties:false
// blocks unknown keys from reaching the handler.
func TestToolRunner_AdditionalPropertiesRejected(t *testing.T) {
t.Parallel()

type StrictInput struct {
Name string `json:"name"`
}

schema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{"type": "string"},
},
"required": []string{"name"},
"additionalProperties": false,
}

handlerCalled := false
tool, err := toolrunner.NewBetaToolFromBytes("strict", "Strict tool", schemaToBytes(t, schema),
func(ctx context.Context, input StrictInput) (anthropic.BetaToolResultBlockParamContentUnion, error) {
handlerCalled = true
return anthropic.BetaToolResultBlockParamContentUnion{
OfText: &anthropic.BetaTextBlockParam{Text: "ok"},
}, nil
})
if err != nil {
t.Fatalf("create tool: %v", err)
}

t.Run("valid input accepted", func(t *testing.T) {
handlerCalled = false
input := json.RawMessage(`{"name": "test"}`)
_, err := tool.Execute(context.Background(), input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !handlerCalled {
t.Fatal("handler was not called")
}
})

t.Run("extra property rejected", func(t *testing.T) {
handlerCalled = false
input := json.RawMessage(`{"name": "test", "extra": "x"}`)
_, err := tool.Execute(context.Background(), input)
if err == nil {
t.Fatal("expected error for additional property, got nil")
}
if handlerCalled {
t.Fatal("handler should NOT be called when additionalProperties is violated")
}
if !strings.Contains(err.Error(), "additional property") {
t.Fatalf("error should mention additional property, got: %v", err)
}
})
}

// TestToolRunner_PatternValidation verifies that pattern constraints on string
// properties are enforced at runtime.
func TestToolRunner_PatternValidation(t *testing.T) {
t.Parallel()

type URLInput struct {
URL string `json:"url"`
}

schema := map[string]any{
"type": "object",
"properties": map[string]any{
"url": map[string]any{
"type": "string",
"pattern": `^https://allowed\.example/`,
},
},
"required": []string{"url"},
}

handlerCalled := false
tool, err := toolrunner.NewBetaToolFromBytes("url_tool", "URL tool", schemaToBytes(t, schema),
func(ctx context.Context, input URLInput) (anthropic.BetaToolResultBlockParamContentUnion, error) {
handlerCalled = true
return anthropic.BetaToolResultBlockParamContentUnion{
OfText: &anthropic.BetaTextBlockParam{Text: "ok"},
}, nil
})
if err != nil {
t.Fatalf("create tool: %v", err)
}

t.Run("matching pattern accepted", func(t *testing.T) {
handlerCalled = false
input := json.RawMessage(`{"url": "https://allowed.example/page"}`)
_, err := tool.Execute(context.Background(), input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !handlerCalled {
t.Fatal("handler was not called")
}
})

t.Run("non-matching pattern rejected", func(t *testing.T) {
handlerCalled = false
input := json.RawMessage(`{"url": "https://evil.example/attack"}`)
_, err := tool.Execute(context.Background(), input)
if err == nil {
t.Fatal("expected error for pattern violation, got nil")
}
if handlerCalled {
t.Fatal("handler should NOT be called when pattern is violated")
}
if !strings.Contains(err.Error(), "pattern") {
t.Fatalf("error should mention pattern, got: %v", err)
}
})
}

// TestToolRunner_StringLengthValidation verifies minLength and maxLength enforcement.
func TestToolRunner_StringLengthValidation(t *testing.T) {
t.Parallel()

type NameInput struct {
Name string `json:"name"`
}

schema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"minLength": 2,
"maxLength": 10,
},
},
"required": []string{"name"},
}

handlerCalled := false
tool, err := toolrunner.NewBetaToolFromBytes("name_tool", "Name tool", schemaToBytes(t, schema),
func(ctx context.Context, input NameInput) (anthropic.BetaToolResultBlockParamContentUnion, error) {
handlerCalled = true
return anthropic.BetaToolResultBlockParamContentUnion{
OfText: &anthropic.BetaTextBlockParam{Text: "ok"},
}, nil
})
if err != nil {
t.Fatalf("create tool: %v", err)
}

t.Run("valid length accepted", func(t *testing.T) {
handlerCalled = false
_, err := tool.Execute(context.Background(), json.RawMessage(`{"name": "Alice"}`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !handlerCalled {
t.Fatal("handler was not called")
}
})

t.Run("too short rejected", func(t *testing.T) {
handlerCalled = false
_, err := tool.Execute(context.Background(), json.RawMessage(`{"name": "A"}`))
if err == nil {
t.Fatal("expected error for minLength violation")
}
if handlerCalled {
t.Fatal("handler should NOT be called")
}
})

t.Run("too long rejected", func(t *testing.T) {
handlerCalled = false
_, err := tool.Execute(context.Background(), json.RawMessage(`{"name": "VeryLongNameHere"}`))
if err == nil {
t.Fatal("expected error for maxLength violation")
}
if handlerCalled {
t.Fatal("handler should NOT be called")
}
})
}

// TestToolRunner_NumericBoundsValidation verifies minimum and maximum enforcement.
func TestToolRunner_NumericBoundsValidation(t *testing.T) {
t.Parallel()

type AgeInput struct {
Age int `json:"age"`
}

schema := map[string]any{
"type": "object",
"properties": map[string]any{
"age": map[string]any{
"type": "integer",
"minimum": 0,
"maximum": 150,
},
},
"required": []string{"age"},
}

handlerCalled := false
tool, err := toolrunner.NewBetaToolFromBytes("age_tool", "Age tool", schemaToBytes(t, schema),
func(ctx context.Context, input AgeInput) (anthropic.BetaToolResultBlockParamContentUnion, error) {
handlerCalled = true
return anthropic.BetaToolResultBlockParamContentUnion{
OfText: &anthropic.BetaTextBlockParam{Text: "ok"},
}, nil
})
if err != nil {
t.Fatalf("create tool: %v", err)
}

t.Run("valid value accepted", func(t *testing.T) {
handlerCalled = false
_, err := tool.Execute(context.Background(), json.RawMessage(`{"age": 25}`))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !handlerCalled {
t.Fatal("handler was not called")
}
})

t.Run("below minimum rejected", func(t *testing.T) {
handlerCalled = false
_, err := tool.Execute(context.Background(), json.RawMessage(`{"age": -1}`))
if err == nil {
t.Fatal("expected error for minimum violation")
}
if handlerCalled {
t.Fatal("handler should NOT be called")
}
})

t.Run("above maximum rejected", func(t *testing.T) {
handlerCalled = false
_, err := tool.Execute(context.Background(), json.RawMessage(`{"age": 200}`))
if err == nil {
t.Fatal("expected error for maximum violation")
}
if handlerCalled {
t.Fatal("handler should NOT be called")
}
})
}
Loading
Loading