diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..4a524b8281 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,27 @@ +Release type: minor + +This release improves schema codegen to generate Python stub classes for custom directive definitions. Previously, schemas containing custom directives would fail with `NotImplementedError`. Now directive definitions are converted to Strawberry schema directive classes. + +For example, this GraphQL schema: +```graphql +directive @authz(resource: String!, action: String!) on FIELD_DEFINITION + +type Query { + hello: String! @authz(resource: "greeting", action: "read") +} +``` + +Now generates: +```python +from strawberry.schema_directive import Location + + +@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) +class Authz: + resource: str + action: str +``` + +Note: The generated directives are stubs - they don't contain any behavior logic, which must be implemented separately. + +This also fixes the error message for unknown definition types to show the actual type name instead of `None`. diff --git a/strawberry/schema_codegen/__init__.py b/strawberry/schema_codegen/__init__.py index 3c8ceb2018..31447a10b5 100644 --- a/strawberry/schema_codegen/__init__.py +++ b/strawberry/schema_codegen/__init__.py @@ -9,6 +9,7 @@ import libcst as cst from graphql import ( + DirectiveDefinitionNode, EnumTypeDefinitionNode, EnumValueDefinitionNode, FieldDefinitionNode, @@ -595,6 +596,83 @@ def _get_union_definition(definition: UnionTypeDefinitionNode) -> Definition: ) +def _get_directive_definition( + definition: DirectiveDefinitionNode, imports: set[Import] +) -> Definition: + """Generate a stub for a custom directive definition.""" + # Get the directive name (will be in camelCase) + directive_name = definition.name.value + # Class name should be PascalCase + class_name = directive_name[0].upper() + directive_name[1:] + + # Build list of Location.X attributes (locations are required per GraphQL spec) + location_elements = [ + cst.Element( + value=cst.Attribute( + value=cst.Name("Location"), attr=cst.Name(location.value) + ) + ) + for location in definition.locations + ] + + decorator_args = [ + cst.Arg( + keyword=cst.Name("locations"), + value=cst.List(elements=location_elements), + equal=cst.AssignEqual(cst.SimpleWhitespace(""), cst.SimpleWhitespace("")), + ) + ] + + if definition.description: + decorator_args.append( + _get_argument("description", definition.description.value) + ) + + # Build decorator + decorator = cst.Decorator( + decorator=cst.Call( + func=cst.Attribute( + value=cst.Name("strawberry"), + attr=cst.Name("schema_directive"), + ), + args=decorator_args, + ) + ) + + # Build fields for the directive class + fields = [] + if definition.arguments: + for arg in definition.arguments: + field_name = to_snake_case(arg.name.value) + field_type = _get_field_type(arg.type) + fields.append( + cst.SimpleStatementLine( + body=[ + cst.AnnAssign( + target=cst.Name(field_name), + annotation=cst.Annotation(field_type), + ) + ] + ) + ) + + # If no fields, add pass statement + if not fields: + fields = [cst.SimpleStatementLine(body=[cst.Pass()])] + + # Build class definition + class_def = cst.ClassDef( + name=cst.Name(class_name), + body=cst.IndentedBlock(body=fields), + decorators=[decorator], + ) + + # Add necessary imports + imports.add(Import(module="strawberry.schema_directive", imports=("Location",))) + + return Definition(class_def, [], class_name) + + def _get_scalar_definition( definition: ScalarTypeDefinitionNode, imports: set[Import] ) -> Definition | None: @@ -748,8 +826,12 @@ def codegen(schema: str) -> str: _is_federation_link_directive(directive) for directive in graphql_definition.directives ) + elif isinstance(graphql_definition, DirectiveDefinitionNode): + definition = _get_directive_definition(graphql_definition, imports) else: - raise NotImplementedError(f"Unknown definition {definition}") + raise NotImplementedError( + f"Unknown definition {type(graphql_definition).__name__}" + ) if definition is not None: definitions[definition.name] = definition diff --git a/tests/cli/test_schema_codegen.py b/tests/cli/test_schema_codegen.py index 9d3e3286a5..b09e70abfa 100644 --- a/tests/cli/test_schema_codegen.py +++ b/tests/cli/test_schema_codegen.py @@ -5,13 +5,21 @@ from typer.testing import CliRunner schema = """ +directive @authz(resource: String!, action: String!) on FIELD_DEFINITION + type Query { - hello: String! + hello: String! @authz(resource: "greeting", action: "read") } """ expected_output = """ import strawberry +from strawberry.schema_directive import Location + +@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) +class Authz: + resource: str + action: str @strawberry.type class Query: diff --git a/tests/schema_codegen/test_directive.py b/tests/schema_codegen/test_directive.py new file mode 100644 index 0000000000..1cf31896be --- /dev/null +++ b/tests/schema_codegen/test_directive.py @@ -0,0 +1,121 @@ +import textwrap + +from strawberry.schema_codegen import codegen + + +def test_directive_with_arguments(): + schema = """ + directive @authz(resource: String!, action: String!) on FIELD_DEFINITION + + type Query { + hello: String! + } + """ + + expected = textwrap.dedent( + """ + import strawberry + from strawberry.schema_directive import Location + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Authz: + resource: str + action: str + + @strawberry.type + class Query: + hello: str + + schema = strawberry.Schema(query=Query) + """ + ).strip() + + assert codegen(schema).strip() == expected + + +def test_directive_without_arguments(): + schema = """ + directive @deprecated on FIELD_DEFINITION + + type Query { + hello: String! + } + """ + + expected = textwrap.dedent( + """ + import strawberry + from strawberry.schema_directive import Location + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Deprecated: + pass + + @strawberry.type + class Query: + hello: str + + schema = strawberry.Schema(query=Query) + """ + ).strip() + + assert codegen(schema).strip() == expected + + +def test_directive_with_description(): + schema = ''' + """Authorization directive for field-level access control""" + directive @authz(resource: String!) on FIELD_DEFINITION + + type Query { + hello: String! + } + ''' + + expected = textwrap.dedent( + """ + import strawberry + from strawberry.schema_directive import Location + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION], description="Authorization directive for field-level access control") + class Authz: + resource: str + + @strawberry.type + class Query: + hello: str + + schema = strawberry.Schema(query=Query) + """ + ).strip() + + assert codegen(schema).strip() == expected + + +def test_directive_with_multiple_locations(): + schema = """ + directive @example on FIELD_DEFINITION | OBJECT | INTERFACE + + type Query { + hello: String! + } + """ + + expected = textwrap.dedent( + """ + import strawberry + from strawberry.schema_directive import Location + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION, Location.OBJECT, Location.INTERFACE]) + class Example: + pass + + @strawberry.type + class Query: + hello: str + + schema = strawberry.Schema(query=Query) + """ + ).strip() + + assert codegen(schema).strip() == expected diff --git a/tests/schema_codegen/test_types.py b/tests/schema_codegen/test_types.py index f2beadcc33..b6e9f8f5db 100644 --- a/tests/schema_codegen/test_types.py +++ b/tests/schema_codegen/test_types.py @@ -1,5 +1,7 @@ import textwrap +import pytest + from strawberry.schema_codegen import codegen @@ -76,6 +78,22 @@ class Example: assert codegen(schema).strip() == expected +def test_codegen_raises_for_unknown_definition(): + # GraphQL operation definitions (queries) are not supported in schema codegen + schema = """ + query GetUser { + user { + name + } + } + """ + + with pytest.raises( + NotImplementedError, match="Unknown definition OperationDefinitionNode" + ): + codegen(schema) + + def test_supports_interfaces(): schema = """ interface Node {