Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Release type: minor

This release improves schema codegen to generate Python stub classes for custom directive definitions. Previously, schemas containing custom directives would fail with `NotImplementedError`. Now directive definitions are converted to Strawberry schema directive classes.

For example, this GraphQL schema:
```graphql
directive @authz(resource: String!, action: String!) on FIELD_DEFINITION

type Query {
hello: String! @authz(resource: "greeting", action: "read")
}
```

Now generates:
```python
from strawberry.schema_directive import Location


@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
class Authz:
resource: str
action: str
```

Note: The generated directives are stubs - they don't contain any behavior logic, which must be implemented separately.

This also fixes the error message for unknown definition types to show the actual type name instead of `None`.
84 changes: 83 additions & 1 deletion strawberry/schema_codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import libcst as cst
from graphql import (
DirectiveDefinitionNode,
EnumTypeDefinitionNode,
EnumValueDefinitionNode,
FieldDefinitionNode,
Expand Down Expand Up @@ -595,6 +596,83 @@ def _get_union_definition(definition: UnionTypeDefinitionNode) -> Definition:
)


def _get_directive_definition(
definition: DirectiveDefinitionNode, imports: set[Import]
) -> Definition:
"""Generate a stub for a custom directive definition."""
# Get the directive name (will be in camelCase)
directive_name = definition.name.value
# Class name should be PascalCase
class_name = directive_name[0].upper() + directive_name[1:]

# Build list of Location.X attributes (locations are required per GraphQL spec)
location_elements = [
cst.Element(
value=cst.Attribute(
value=cst.Name("Location"), attr=cst.Name(location.value)
)
)
for location in definition.locations
]

decorator_args = [
cst.Arg(
keyword=cst.Name("locations"),
value=cst.List(elements=location_elements),
equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")),
)
]

if definition.description:
decorator_args.append(
_get_argument("description", definition.description.value)
)

# Build decorator
decorator = cst.Decorator(
decorator=cst.Call(
func=cst.Attribute(
value=cst.Name("strawberry"),
attr=cst.Name("schema_directive"),
),
args=decorator_args,
)
)

# Build fields for the directive class
fields = []
if definition.arguments:
for arg in definition.arguments:
field_name = to_snake_case(arg.name.value)
field_type = _get_field_type(arg.type)
fields.append(
cst.SimpleStatementLine(
body=[
cst.AnnAssign(
target=cst.Name(field_name),
annotation=cst.Annotation(field_type),
)
]
)
)

# If no fields, add pass statement
if not fields:
fields = [cst.SimpleStatementLine(body=[cst.Pass()])]

# Build class definition
class_def = cst.ClassDef(
name=cst.Name(class_name),
body=cst.IndentedBlock(body=fields),
decorators=[decorator],
)

# Add necessary imports
imports.add(Import(module="strawberry.schema_directive", imports=("Location",)))

return Definition(class_def, [], class_name)


def _get_scalar_definition(
definition: ScalarTypeDefinitionNode, imports: set[Import]
) -> Definition | None:
Expand Down Expand Up @@ -748,8 +826,12 @@ def codegen(schema: str) -> str:
_is_federation_link_directive(directive)
for directive in graphql_definition.directives
)
elif isinstance(graphql_definition, DirectiveDefinitionNode):
definition = _get_directive_definition(graphql_definition, imports)
else:
raise NotImplementedError(f"Unknown definition {definition}")
raise NotImplementedError(
f"Unknown definition {type(graphql_definition).__name__}"
)

if definition is not None:
definitions[definition.name] = definition
Expand Down
10 changes: 9 additions & 1 deletion tests/cli/test_schema_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@
from typer.testing import CliRunner

schema = """
directive @authz(resource: String!, action: String!) on FIELD_DEFINITION

type Query {
hello: String!
hello: String! @authz(resource: "greeting", action: "read")
}
"""

expected_output = """
import strawberry
from strawberry.schema_directive import Location

@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
class Authz:
resource: str
action: str

@strawberry.type
class Query:
Expand Down
121 changes: 121 additions & 0 deletions tests/schema_codegen/test_directive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import textwrap

from strawberry.schema_codegen import codegen


def test_directive_with_arguments():
schema = """
directive @authz(resource: String!, action: String!) on FIELD_DEFINITION

type Query {
hello: String!
}
"""

expected = textwrap.dedent(
"""
import strawberry
from strawberry.schema_directive import Location

@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
class Authz:
resource: str
action: str

@strawberry.type
class Query:
hello: str

schema = strawberry.Schema(query=Query)
"""
).strip()

assert codegen(schema).strip() == expected


def test_directive_without_arguments():
schema = """
directive @deprecated on FIELD_DEFINITION

type Query {
hello: String!
}
"""

expected = textwrap.dedent(
"""
import strawberry
from strawberry.schema_directive import Location

@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
class Deprecated:
pass

@strawberry.type
class Query:
hello: str

schema = strawberry.Schema(query=Query)
"""
).strip()

assert codegen(schema).strip() == expected


def test_directive_with_description():
schema = '''
"""Authorization directive for field-level access control"""
directive @authz(resource: String!) on FIELD_DEFINITION

type Query {
hello: String!
}
'''

expected = textwrap.dedent(
"""
import strawberry
from strawberry.schema_directive import Location

@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION], description="Authorization directive for field-level access control")
class Authz:
resource: str

@strawberry.type
class Query:
hello: str

schema = strawberry.Schema(query=Query)
"""
).strip()

assert codegen(schema).strip() == expected


def test_directive_with_multiple_locations():
schema = """
directive @example on FIELD_DEFINITION | OBJECT | INTERFACE

type Query {
hello: String!
}
"""

expected = textwrap.dedent(
"""
import strawberry
from strawberry.schema_directive import Location

@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION, Location.OBJECT, Location.INTERFACE])
class Example:
pass

@strawberry.type
class Query:
hello: str

schema = strawberry.Schema(query=Query)
"""
).strip()

assert codegen(schema).strip() == expected
18 changes: 18 additions & 0 deletions tests/schema_codegen/test_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import textwrap

import pytest

from strawberry.schema_codegen import codegen


Expand Down Expand Up @@ -76,6 +78,22 @@ class Example:
assert codegen(schema).strip() == expected


def test_codegen_raises_for_unknown_definition():
# GraphQL operation definitions (queries) are not supported in schema codegen
schema = """
query GetUser {
user {
name
}
}
"""

with pytest.raises(
NotImplementedError, match="Unknown definition OperationDefinitionNode"
):
codegen(schema)


def test_supports_interfaces():
schema = """
interface Node {
Expand Down
Loading