Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 80 additions & 20 deletions gql/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from graphql import (
ArgumentNode,
BooleanValueNode,
ConstDirectiveNode,
DirectiveLocation,
DirectiveNode,
DocumentNode,
Expand Down Expand Up @@ -433,7 +434,10 @@ def args(self, **kwargs: Any) -> Self:
arguments=tuple(
ArgumentNode(
name=NameNode(value=key),
value=ast_from_value(value, self.directive_def.args[key].type),
value=cast(
ValueNode,
ast_from_value(value, self.directive_def.args[key].type),
),
)
for key, value in kwargs.items()
),
Expand Down Expand Up @@ -596,7 +600,13 @@ def alias(self, alias: str) -> Self:
:return: itself
"""

self.ast_field.alias = NameNode(value=alias)
self.ast_field = FieldNode(
name=self.ast_field.name,
alias=NameNode(value=alias),
arguments=self.ast_field.arguments,
directives=self.ast_field.directives,
selection_set=self.ast_field.selection_set,
)
return self


Expand Down Expand Up @@ -667,7 +677,9 @@ def select(
] = tuple(field.ast_field for field in added_fields)

# Update the current selection list with new selections
self.selection_set.selections = self.selection_set.selections + added_selections
self.selection_set = SelectionSetNode(
selections=self.selection_set.selections + added_selections
)

log.debug(f"Added fields: {added_fields} in {self!r}")

Expand Down Expand Up @@ -799,7 +811,7 @@ def executable_ast(self) -> OperationDefinitionNode:
operation=OperationType(self.operation_type),
selection_set=self.selection_set,
variable_definitions=self.variable_definitions.get_ast_definitions(),
**({"name": NameNode(value=self.name)} if self.name else {}),
name=NameNode(value=self.name) if self.name else None,
directives=self.directives_ast,
)

Expand Down Expand Up @@ -857,7 +869,12 @@ def to_ast_type(self, type_: GraphQLInputType) -> TypeNode:
return ListTypeNode(type=self.to_ast_type(type_.of_type))

elif isinstance(type_, GraphQLNonNull):
return NonNullTypeNode(type=self.to_ast_type(type_.of_type))
return NonNullTypeNode(
type=cast(
Union[NamedTypeNode, ListTypeNode],
self.to_ast_type(type_.of_type),
)
)

assert isinstance(
type_, (GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType)
Expand Down Expand Up @@ -924,14 +941,14 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]:
"""
return tuple(
VariableDefinitionNode(
type=var.ast_variable_type,
type=cast(TypeNode, var.ast_variable_type),
variable=var.ast_variable_name,
default_value=(
None
if var.default_value is None
else ast_from_value(var.default_value, var.type)
),
directives=var.directives_ast,
directives=cast(Tuple[ConstDirectiveNode, ...], var.directives_ast),
)
for var in self.variables.values()
if var.type is not None # only variables used
Expand Down Expand Up @@ -1141,13 +1158,23 @@ def args(self, **kwargs: Any) -> Self:

assert self.ast_field.arguments is not None

self.ast_field.arguments = self.ast_field.arguments + tuple(
new_arguments = self.ast_field.arguments + tuple(
ArgumentNode(
name=NameNode(value=name),
value=ast_from_value(value, self._get_argument(name).type),
value=cast(
ValueNode,
ast_from_value(value, self._get_argument(name).type),
),
)
for name, value in kwargs.items()
)
self.ast_field = FieldNode(
name=self.ast_field.name,
alias=self.ast_field.alias,
arguments=new_arguments,
directives=self.ast_field.directives,
selection_set=self.ast_field.selection_set,
)

log.debug(f"Added arguments {kwargs} in field {self!r})")

Expand Down Expand Up @@ -1175,14 +1202,26 @@ def select(
"""

super().select(*fields, **fields_with_alias)
self.ast_field.selection_set = self.selection_set
self.ast_field = FieldNode(
name=self.ast_field.name,
alias=self.ast_field.alias,
arguments=self.ast_field.arguments,
directives=self.ast_field.directives,
selection_set=self.selection_set,
)

return self

def directives(self, *directives: DSLDirective) -> Self:
"""Add directives to this field."""
super().directives(*directives)
self.ast_field.directives = self.directives_ast
self.ast_field = FieldNode(
name=self.ast_field.name,
alias=self.ast_field.alias,
arguments=self.ast_field.arguments,
directives=self.directives_ast,
selection_set=self.ast_field.selection_set,
)

return self

Expand Down Expand Up @@ -1254,7 +1293,10 @@ def __init__(

log.debug(f"Creating {self!r}")

self.ast_field = InlineFragmentNode(directives=())
self.ast_field = InlineFragmentNode(
selection_set=SelectionSetNode(selections=()),
directives=(),
)

DSLSelector.__init__(self, *fields, **fields_with_alias)
DSLDirectable.__init__(self)
Expand All @@ -1266,16 +1308,22 @@ def select(
corrected typing hints
"""
super().select(*fields, **fields_with_alias)
self.ast_field.selection_set = self.selection_set
self.ast_field = InlineFragmentNode(
selection_set=self.selection_set,
type_condition=self.ast_field.type_condition,
directives=self.ast_field.directives,
)

return self

def on(self, type_condition: DSLType) -> Self:
"""Provides the GraphQL type of this inline fragment."""

self._type = type_condition._type
self.ast_field.type_condition = NamedTypeNode(
name=NameNode(value=self._type.name)
self.ast_field = InlineFragmentNode(
selection_set=self.ast_field.selection_set,
type_condition=NamedTypeNode(name=NameNode(value=self._type.name)),
directives=self.ast_field.directives,
)
return self

Expand All @@ -1285,7 +1333,11 @@ def directives(self, *directives: DSLDirective) -> Self:
Inline fragments support all directive types through auto-validation.
"""
super().directives(*directives)
self.ast_field.directives = self.directives_ast
self.ast_field = InlineFragmentNode(
selection_set=self.ast_field.selection_set,
type_condition=self.ast_field.type_condition,
directives=self.directives_ast,
)
return self

def __repr__(self) -> str:
Expand Down Expand Up @@ -1338,7 +1390,10 @@ def directives(self, *directives: DSLDirective) -> Self:
Fragment spreads support all directive types through auto-validation.
"""
super().directives(*directives)
self.ast_field.directives = self.directives_ast
self.ast_field = FragmentSpreadNode(
name=self.ast_field.name,
directives=self.directives_ast,
)
return self

def is_valid_directive(self, directive: DSLDirective) -> bool:
Expand Down Expand Up @@ -1382,7 +1437,10 @@ def name(self) -> str:
def name(self, value: str) -> None:
""":meta private:"""
if hasattr(self, "ast_field"):
self.ast_field.name.value = value
self.ast_field = FragmentSpreadNode(
name=NameNode(value=value),
directives=self.ast_field.directives,
)

def spread(self) -> DSLFragmentSpread:
"""Create a fragment spread that can have its own directives.
Expand Down Expand Up @@ -1435,6 +1493,8 @@ def executable_ast(self) -> FragmentDefinitionNode:

fragment_variable_definitions = self.variable_definitions.get_ast_definitions()

variable_definition_kwargs: Dict[str, Any]

if len(fragment_variable_definitions) == 0:
"""Fragment variable definitions are obsolete and only supported on
graphql-core if the Parser is initialized with:
Expand All @@ -1452,9 +1512,9 @@ def executable_ast(self) -> FragmentDefinitionNode:
return FragmentDefinitionNode(
type_condition=NamedTypeNode(name=NameNode(value=self._type.name)),
selection_set=self.selection_set,
**variable_definition_kwargs,
name=NameNode(value=self.name),
directives=self.directives_ast,
**variable_definition_kwargs,
)

def is_valid_directive(self, directive: DSLDirective) -> bool:
Expand Down Expand Up @@ -1516,7 +1576,7 @@ def dsl_gql(
)

document = DocumentNode(
definitions=[operation.executable_ast for operation in all_operations]
definitions=tuple(operation.executable_ast for operation in all_operations)
)

return GraphQLRequest(document)
9 changes: 5 additions & 4 deletions gql/utilities/get_introspection_query_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,13 @@ def get_introspection_query_ast(

if type_recursion_level >= 1:
current_field = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
fragment_TypeRef.select(current_field)

for _ in repeat(None, type_recursion_level - 1):
new_oftype = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
current_field.select(new_oftype)
current_field = new_oftype
parent_field = ds.__Type.ofType.select(ds.__Type.kind, ds.__Type.name)
parent_field.select(current_field)
current_field = parent_field

fragment_TypeRef.select(current_field)

query = DSLQuery(schema)

Expand Down
5 changes: 4 additions & 1 deletion gql/utilities/node_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _node_tree_recursive(

results = []

if hasattr(obj, "__slots__"):
if hasattr(obj, "__slots__") or isinstance(obj, Node):

results.append(" " * indent + f"{type(obj).__name__}")

Expand Down Expand Up @@ -89,4 +89,7 @@ def node_tree(
# We are ignoring block attributes by default (in StringValueNode)
ignored_keys.append("block")

# Ignore new field added in graphql-core 3.3.0a12 to keep output compatible
ignored_keys.append("nullability_assertion")

return _node_tree_recursive(obj, ignored_keys=ignored_keys)
8 changes: 4 additions & 4 deletions gql/utilities/parse_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ def enter_operation_definition(
if not hasattr(node.name, "value"):
return REMOVE # pragma: no cover

node.name = cast(NameNode, node.name)
name = cast(NameNode, node.name)

if node.name.value != self.operation_name:
log.debug(f"SKIPPING operation {node.name.value}")
if name.value != self.operation_name:
log.debug(f"SKIPPING operation {name.value}")
return REMOVE

return IDLE
Expand Down Expand Up @@ -238,7 +238,7 @@ def enter_field(
assert isinstance(selection_set_node, SelectionSetNode)

# Keep only the current node in a new selection set node
new_node = SelectionSetNode(selections=[node])
new_node = SelectionSetNode(selections=(node,))

for item in result_value:

Expand Down
Loading