Controlled Walking of Nested AST Nodes

Hi all! I’m one of the maintainers for flake8-annotations, a plugin for detecting the absence of PEP 3107-style function annotations and PEP 484-style type comments.

I thought I had handled nested definitions well, at least according to the test cases I had initially come up with, but there have been a few bugs exposed in the flow of parsing the AST that I’m having trouble coming up with a good way to deal with. The current approach being taken for nesting is to use a generic_visit() on the node to visit any nested functions. However, from the bugs identified I believe this approach is too generic & I need to derive a method for restricting the depth of this generic visit to one level rather than to the full depth.

For reference, the full parser can be found here (latest dev branch), but I’ll trim down into self-contained examples below. To note for later, beyond the ast.NodeVisitor subclasses there are also Function and Argument classes that are used to represent function definitions and arguments & contain the information needed for classification of potential linting errors.

Recently, we introduced an opinionated flag to suppress missing return annotations if the function implicitly or explicitly returns None. This was accomplished by adding in an ast.NodeVisitor subclass that contains only a visit_Return method. This subclass is invoked as a checker when a Function instance is being created from a FunctionDef or AsyncFunctionDef node.

Here’s the visitor:

ReturnVisitor class
class ReturnVisitor(ast.NodeVisitor):

    def __init__(self):
        self.has_only_none_returns = True

    def visit_Return(self, node: ast.Return) -> None:
        if node.value is not None:
            if isinstance(node.value, (ast.Constant, ast.NameConstant)):
                if node.value.value is None:
                    return

            self.has_only_none_returns = False

And some sample code:

def foo():
    def bar() -> bool:
        return True
    bar()

With the opinionated flag set, the linter should not yield any linting errors. However, what actually happens is that when constructing the Function instance for foo, bar's return statement is visited and treated as a return for foo:

Full Example
import ast
from textwrap import dedent


class ReturnVisitor(ast.NodeVisitor):

    def __init__(self):
        self.has_only_none_returns = True

    def visit_Return(self, node: ast.Return) -> None:
        print(f"Node dump: {ast.dump(node)}")
        if node.value is not None:
            if isinstance(node.value, (ast.Constant, ast.NameConstant)):
                if node.value.value is None:
                    return

            self.has_only_none_returns = False


src = dedent(
    """\
    def foo():
        def bar() -> bool:
            return True
        bar()
    """
)

tree = ast.parse(src)
foo_visitor = ReturnVisitor()
foo_visitor.visit(tree)
print(f"foo has only none returns? {foo_visitor.has_only_none_returns}")

Which prints out:

Node dump: Return(value=Constant(value=True, kind=None))
foo has only none returns? False

I’m running into what I believe is the same issue when parsing class methods, which are flagged explicitly as there are error codes specific to class methods. The function node is passed this flag by the visit_ClassDef method of the ast.NodeVisitor subclass being used for the general function parsing:

FunctionVisitor class
class FunctionVisitor(ast.NodeVisitor):

    def __init__(self, lines: List[str]):
        self.lines = lines
        self.function_definitions = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.function_definitions.append(Function.from_function_node(node, self.lines))
        self.generic_visit(node)  # Walk through any nested functions

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self.function_definitions.append(Function.from_function_node(node, self.lines))
        self.generic_visit(node)  # Walk through any nested functions

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        method_nodes = [
            child_node
            for child_node in node.body
            if isinstance(child_node, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        self.function_definitions.extend(
            [
                Function.from_function_node(method_node, self.lines, is_class_method=True)
                for method_node in method_nodes
            ]
        )

        # Use ast.NodeVisitor.generic_visit to start down the nested method chain
        for sub_node in node.body:
            self.generic_visit(sub_node)

While this handles nested functions of its own class methods, it fails to properly flag methods of nested classes. For example:

class Foo:
    class Bar:
        def bar_method(self):
            pass
    def foo_method(self):
        pass

Here, bar_method is not classified as a class method of Bar. Because visit_ClassDef calls generic_visit() on each node in the class body, Bar itself never goes through visit_ClassDef (which is what passes the class method flag); bar_method is only picked up by the more generic visit_FunctionDef.

(Sorry, this one is quite long, I’ve trimmed out what I can)

Full Example
import ast
from textwrap import dedent
from typing import Any, List, Union

AST_FUNCTION_TYPES = Union[ast.FunctionDef, ast.AsyncFunctionDef]
AST_ARG_TYPES = ("args",)  # Simplified for example


class Argument:

    def __init__(
        self,
        argname: str,
        lineno: int,
        col_offset: int,
        annotation_type: str,  # Snipped for example
        has_type_annotation: bool = False,
        has_3107_annotation: bool = False,
    ):
        self.argname = argname
        self.lineno = lineno
        self.col_offset = col_offset
        self.annotation_type = annotation_type
        self.has_type_annotation = has_type_annotation
        self.has_3107_annotation = has_3107_annotation

    def __repr__(self) -> str:
        """Format the Argument object into its "official" representation."""
        return (
            f"Argument(argname={self.argname!r}, lineno={self.lineno}, col_offset={self.col_offset}, "
            f"annotation_type={self.annotation_type}, has_type_annotation={self.has_type_annotation}, "
            f"has_3107_annotation={self.has_3107_annotation})"
        )

    @classmethod
    def from_arg_node(cls, node: ast.arg, annotation_type_name: str):
        annotation_type = "Snipped"
        new_arg = cls(node.arg, node.lineno, node.col_offset, annotation_type)

        new_arg.has_type_annotation = False
        if node.annotation:
            new_arg.has_type_annotation = True
            new_arg.has_3107_annotation = True

        return new_arg


class Function:

    def __init__(
        self,
        name: str,
        lineno: int,
        col_offset: int,
        function_type: str = "PUBLIC",  # Simplified for example
        is_class_method: bool = False,
        class_decorator_type: Any = None,  # Simplified for example
        is_return_annotated: bool = False,
        args: List[Argument] = None,
    ):
        self.name = name
        self.lineno = lineno
        self.col_offset = col_offset
        self.function_type = function_type
        self.is_class_method = is_class_method
        self.class_decorator_type = class_decorator_type
        self.is_return_annotated = is_return_annotated
        self.args = args

    def is_fully_annotated(self) -> bool:
        return all(arg.has_type_annotation for arg in self.args)

    def get_missed_annotations(self) -> List:
        return [arg for arg in self.args if not arg.has_type_annotation]

    def get_annotated_arguments(self) -> List:
        return [arg for arg in self.args if arg.has_type_annotation]

    def __repr__(self) -> str:
        return (
            f"Function(name={self.name!r}, lineno={self.lineno}, col_offset={self.col_offset}, "
            f"function_type={self.function_type}, is_class_method={self.is_class_method}, "
            f"class_decorator_type={self.class_decorator_type}, "
            f"is_return_annotated={self.is_return_annotated})"
        )

    @classmethod
    def from_function_node(cls, node: AST_FUNCTION_TYPES, lines: List[str], **kwargs):
        # Extract function types from function name
        kwargs["function_type"] = "PUBLIC"  # Simplified for example
        if kwargs.get("is_class_method", False):
            kwargs["class_decorator_type"] = None  # Simplified for example

        new_function = cls(node.name, node.lineno, node.col_offset, **kwargs)

        # Iterate over arguments by type & add
        new_function.args = []
        for arg_type in AST_ARG_TYPES:
            args = getattr(node.args, arg_type)
            if args:
                if not isinstance(args, list):
                    args = [args]

                new_function.args.extend(
                    [Argument.from_arg_node(arg, arg_type.upper()) for arg in args]
                )

        # Create an Argument object for the return hint
        # Get the line number from the line before where the body of the function starts to account
        # for the presence of decorators
        def_end_lineno = node.body[0].lineno - 1
        while True:
            # To account for multiline docstrings, rewind through the lines until we find the line
            # containing the :
            # Use str.rfind() to account for annotations on the same line, definition closure should
            # be the last : on the line
            colon_loc = lines[def_end_lineno - 1].rfind(":")
            if colon_loc == -1:
                def_end_lineno -= 1
            else:
                # Lineno is 1-indexed, the line string is 0-indexed
                def_end_col_offset = colon_loc + 1
                break

        return_arg = Argument("return", def_end_lineno, def_end_col_offset, "RETURN")
        if node.returns:
            return_arg.has_type_annotation = True
            return_arg.has_3107_annotation = True
            new_function.is_return_annotated = True

        new_function.args.append(return_arg)

        return new_function


class FunctionVisitor(ast.NodeVisitor):

    def __init__(self, lines: List[str]):
        self.lines = lines
        self.function_definitions = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        print("visit_FunctionDef visited")
        self.function_definitions.append(Function.from_function_node(node, self.lines))
        self.generic_visit(node)  # Walk through any nested functions

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        print("visit_ClassDef visited")
        method_nodes = [
            child_node
            for child_node in node.body
            if isinstance(child_node, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        self.function_definitions.extend(
            [
                Function.from_function_node(method_node, self.lines, is_class_method=True)
                for method_node in method_nodes
            ]
        )

        # Use ast.NodeVisitor.generic_visit to start down the nested method chain
        for sub_node in node.body:
            self.generic_visit(sub_node)


src = dedent(
    """\
    class Foo:
        class Bar:
            def bar_method(self):
                pass
        def foo_method(self):
            pass
    """
)

lines = src.splitlines()
tree = ast.parse(src)
visitor = FunctionVisitor(lines)
visitor.visit(tree)
defs = "\n".join(repr(fun) for fun in visitor.function_definitions)
print(f"\nFunction definitions:\n{defs}")

Which prints out:

visit_ClassDef visited
visit_FunctionDef visited

Function definitions:
Function(name='foo_method', lineno=5, col_offset=4, function_type=PUBLIC, is_class_method=True, class_decorator_type=None, is_return_annotated=False)
Function(name='bar_method', lineno=3, col_offset=8, function_type=PUBLIC, is_class_method=False, class_decorator_type=None, is_return_annotated=False)
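
For what it's worth, the output above is consistent with how ast.NodeVisitor dispatches: generic_visit(node) calls visit() on each child of node, never on node itself. A minimal tracing sketch of that behaviour follows; TraceVisitor and its calls list are names made up for illustration and are not part of the plugin.

```python
import ast
from textwrap import dedent


class TraceVisitor(ast.NodeVisitor):
    """Hypothetical visitor recording which visit_* method handled which node."""

    def __init__(self) -> None:
        self.calls = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.calls.append(("visit_FunctionDef", node.name))
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self.calls.append(("visit_ClassDef", node.name))
        # Same pattern as FunctionVisitor: generic_visit() on a body node
        # dispatches that node's children, not the node itself
        for sub_node in node.body:
            self.generic_visit(sub_node)


src = dedent(
    """\
    class Foo:
        class Bar:
            def bar_method(self):
                pass
        def foo_method(self):
            pass
    """
)

visitor = TraceVisitor()
visitor.visit(ast.parse(src))
print(visitor.calls)
```

Bar never reaches visit_ClassDef and foo_method never reaches visit_FunctionDef, so the only recorded calls are visit_ClassDef for Foo and visit_FunctionDef for bar_method.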

If you’ve made it this far, thank you! :)

To reiterate from above: The current approach being taken for nesting is to use a generic_visit() on the node to visit any nested functions. However, from the bugs identified I believe this approach is too generic & I need to derive a method for restricting the depth of this generic visit to one level rather than to the full depth. I’m struggling with trying to grok a method for doing this that doesn’t turn into a total monstrosity, and would love any pointers!

If I’ve missed anything in my explanation or the examples, please let me know!


So the general problem looks like you don’t handle context switches on sub-functions.

def x(): # <= context switch
    def y(): # <= context switch
        return # <= this now belongs to the y context
    y() # <= this belongs to the x context

What I can suggest is keeping a stack of active contexts, pushing on entry to each FunctionDef and AsyncFunctionDef and popping on exit (also, if you are looking for a more generalized algorithm, I can suggest inspectortiger’s context system):

class ReturnVisitor(ast.NodeVisitor):

    def __init__(self):
        self.context = []
        self.bad_functions = set()

    def visit_Return(self, node: ast.Return) -> None:
        print(f"Node dump: {ast.dump(node)}")
        if node.value is not None:
            if isinstance(node.value, (ast.Constant, ast.NameConstant)):
                if node.value.value is None:
                    return

            self.bad_functions.add(self.context[-1].name)

    def change_context(self, node):
        self.context.append(node)
        self.generic_visit(node)
        self.context.pop()

    visit_FunctionDef = change_context
    visit_AsyncFunctionDef = change_context

Running this against your sample and printing foo_visitor.bad_functions gives:

{'bar'}

Thanks for the tip!

I’ve made a new example based on the feedback:

Example
import ast
from textwrap import dedent
from typing import Union

AST_FUNCTION_TYPES = Union[ast.FunctionDef, ast.AsyncFunctionDef]


class ReturnVisitor(ast.NodeVisitor):

    def __init__(self):
        self.has_only_none_returns = True
        self._context = []

    def visit_Return(self, node: ast.Return) -> None:
        print(f"Return visited: {ast.dump(node)}\n")
        if node.value is not None:
            # In the event of an explicit `None` return (`return None`), the node body will be an
            # instance of either `ast.Constant` (3.8+) or `ast.NameConstant`, which we need to check
            # to see if it's actually `None`
            if isinstance(node.value, (ast.Constant, ast.NameConstant)):
                if node.value.value is None:
                    return

            self.has_only_none_returns = False

    def switch_context(self, node: AST_FUNCTION_TYPES) -> None:
        self._context.append(node)
        print(f"Generic visiting: {ast.dump(node)}\n")
        self.generic_visit(node)
        self._context.pop()

    visit_FunctionDef = switch_context
    visit_AsyncFunctionDef = switch_context


src = dedent(
    """\
    def foo():
        def bar() -> bool:
            return True
        bar()
    """
)

tree = ast.parse(src)
foo_visitor = ReturnVisitor()
foo_visitor.visit(tree)
print(f"foo has only none returns? {foo_visitor.has_only_none_returns}")

Which is still yielding the incorrect result, though I understand why:

Generic visiting: FunctionDef(name='foo', args=arguments(posonlyargs=[], args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[FunctionDef(name='bar', args=arguments(posonlyargs=[], args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Return(value=Constant(value=True, kind=None))], decorator_list=[], returns=Name(id='bool', ctx=Load()), type_comment=None), Expr(value=Call(func=Name(id='bar', ctx=Load()), args=[], keywords=[]))], decorator_list=[], returns=None, type_comment=None)
Generic visiting: FunctionDef(name='bar', args=arguments(posonlyargs=[], args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Return(value=Constant(value=True, kind=None))], decorator_list=[], returns=Name(id='bool', ctx=Load()), type_comment=None)
Return visited: Return(value=Constant(value=True, kind=None))

foo has only none returns? False

I believe I have 2 approaches for a solution here:

  1. Adjust ReturnVisitor to be essentially identical to your example, which I believe would store which function’s return is being visited, then compare that upstream where ReturnVisitor is invoked to see if the top-level node (foo, in this case) is contained in the bad_functions set.
  2. Identify an alternative method to generic_visit, so that everything in the body of the top-level node (again, foo) is ignored except for any Return nodes that are direct children of foo. This was my original thinking, but maybe I’m heading into XY problem territory here?
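
For comparison, the second option can probably be sketched without a context stack: visit the body of the top-level node directly and decline to descend into nested definitions. This is only a sketch under some assumptions: DirectReturnVisitor and function_has_only_none_returns are made-up names, and it assumes Python 3.8+, where `return None` parses to ast.Constant.

```python
import ast
from textwrap import dedent


class DirectReturnVisitor(ast.NodeVisitor):
    """Hypothetical visitor that only sees returns belonging to one function."""

    def __init__(self) -> None:
        self.has_only_none_returns = True

    def visit_Return(self, node: ast.Return) -> None:
        if node.value is None:
            return  # Bare `return`
        if isinstance(node.value, ast.Constant) and node.value.value is None:
            return  # Explicit `return None`
        self.has_only_none_returns = False

    def _skip(self, node: ast.AST) -> None:
        """Do not descend: returns in a nested definition belong to that definition."""

    visit_FunctionDef = _skip
    visit_AsyncFunctionDef = _skip


def function_has_only_none_returns(node) -> bool:
    visitor = DirectReturnVisitor()
    # Visit the body statements rather than `node` itself, otherwise the
    # top-level definition would be skipped along with the nested ones
    for statement in node.body:
        visitor.visit(statement)
    return visitor.has_only_none_returns


src = dedent(
    """\
    def foo():
        def bar() -> bool:
            return True
        if bar():
            return None
        return
    """
)

foo_node = ast.parse(src).body[0]
bar_node = foo_node.body[0]
print(function_has_only_none_returns(foo_node))  # True
print(function_has_only_none_returns(bar_node))  # False
```

Returns nested inside if/for/try blocks of foo are still found, since generic_visit continues through those; only nested function definitions are cut off.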

I haven’t yet explored the class approach, which I’ll look at after I’ve squared away the ReturnVisitor case.

Edit: I went ahead and implemented the first case and it works great! I used a property method to check if the top-level node is in the set of non-None-returning nodes, so the change is almost completely transparent upstream. On to the nested classes!
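
Roughly, the first case looks like the sketch below. It is trimmed down rather than the plugin's actual code: parent_node and _non_none_return_nodes are illustrative names, it stores nodes instead of names so that two functions with the same name can't collide, and it assumes Python 3.8+ (ast.Constant only).

```python
import ast
from textwrap import dedent


class ReturnVisitor(ast.NodeVisitor):
    """Sketch: record returns per function node, then query the top-level node."""

    def __init__(self, parent_node) -> None:
        self.parent_node = parent_node
        self._context = []
        self._non_none_return_nodes = set()

    @property
    def has_only_none_returns(self) -> bool:
        # True unless a non-None return was attributed to the top-level node
        return self.parent_node not in self._non_none_return_nodes

    def visit_Return(self, node: ast.Return) -> None:
        if node.value is None:
            return
        if isinstance(node.value, ast.Constant) and node.value.value is None:
            return
        # Attribute the return to the innermost enclosing function
        self._non_none_return_nodes.add(self._context[-1])

    def switch_context(self, node) -> None:
        self._context.append(node)
        self.generic_visit(node)
        self._context.pop()

    visit_FunctionDef = switch_context
    visit_AsyncFunctionDef = switch_context


src = dedent(
    """\
    def foo():
        def bar() -> bool:
            return True
        bar()
    """
)

foo_node = ast.parse(src).body[0]
visitor = ReturnVisitor(foo_node)
visitor.visit(foo_node)
print(f"foo has only none returns? {visitor.has_only_none_returns}")
```

Callers still read has_only_none_returns as before; the only upstream difference in this sketch is passing the function node to the constructor.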


Concur with @isidentical.

Well, it turns out context switching makes the FunctionVisitor even simpler!

Full example
import ast
from textwrap import dedent
from typing import Any, List, Union

AST_FUNCTION_TYPES = Union[ast.FunctionDef, ast.AsyncFunctionDef]
AST_ARG_TYPES = ("args",)  # Simplified for example


class Argument:

    def __init__(
        self,
        argname: str,
        lineno: int,
        col_offset: int,
        annotation_type: str,  # Snipped for example
        has_type_annotation: bool = False,
        has_3107_annotation: bool = False,
    ):
        self.argname = argname
        self.lineno = lineno
        self.col_offset = col_offset
        self.annotation_type = annotation_type
        self.has_type_annotation = has_type_annotation
        self.has_3107_annotation = has_3107_annotation

    def __repr__(self) -> str:
        """Format the Argument object into its "official" representation."""
        return (
            f"Argument(argname={self.argname!r}, lineno={self.lineno}, col_offset={self.col_offset}, "
            f"annotation_type={self.annotation_type}, has_type_annotation={self.has_type_annotation}, "
            f"has_3107_annotation={self.has_3107_annotation})"
        )

    @classmethod
    def from_arg_node(cls, node: ast.arg, annotation_type_name: str):
        annotation_type = "Snipped"
        new_arg = cls(node.arg, node.lineno, node.col_offset, annotation_type)

        new_arg.has_type_annotation = False
        if node.annotation:
            new_arg.has_type_annotation = True
            new_arg.has_3107_annotation = True

        return new_arg


class Function:

    def __init__(
        self,
        name: str,
        lineno: int,
        col_offset: int,
        function_type: str = "PUBLIC",  # Simplified for example
        is_class_method: bool = False,
        class_decorator_type: Any = None,  # Simplified for example
        is_return_annotated: bool = False,
        args: List[Argument] = None,
    ):
        self.name = name
        self.lineno = lineno
        self.col_offset = col_offset
        self.function_type = function_type
        self.is_class_method = is_class_method
        self.class_decorator_type = class_decorator_type
        self.is_return_annotated = is_return_annotated
        self.args = args

    def is_fully_annotated(self) -> bool:
        return all(arg.has_type_annotation for arg in self.args)

    def get_missed_annotations(self) -> List:
        return [arg for arg in self.args if not arg.has_type_annotation]

    def get_annotated_arguments(self) -> List:
        return [arg for arg in self.args if arg.has_type_annotation]

    def __repr__(self) -> str:
        return (
            f"Function(name={self.name!r}, lineno={self.lineno}, col_offset={self.col_offset}, "
            f"function_type={self.function_type}, is_class_method={self.is_class_method}, "
            f"class_decorator_type={self.class_decorator_type}, "
            f"is_return_annotated={self.is_return_annotated})"
        )

    @classmethod
    def from_function_node(cls, node: AST_FUNCTION_TYPES, lines: List[str], **kwargs):
        # Extract function types from function name
        kwargs["function_type"] = "PUBLIC"  # Simplified for example
        if kwargs.get("is_class_method", False):
            kwargs["class_decorator_type"] = None  # Simplified for example

        new_function = cls(node.name, node.lineno, node.col_offset, **kwargs)

        # Iterate over arguments by type & add
        new_function.args = []
        for arg_type in AST_ARG_TYPES:
            args = getattr(node.args, arg_type)
            if args:
                if not isinstance(args, list):
                    args = [args]

                new_function.args.extend(
                    [Argument.from_arg_node(arg, arg_type.upper()) for arg in args]
                )

        # Create an Argument object for the return hint
        # Get the line number from the line before where the body of the function starts to account
        # for the presence of decorators
        def_end_lineno = node.body[0].lineno - 1
        while True:
            # To account for multiline docstrings, rewind through the lines until we find the line
            # containing the :
            # Use str.rfind() to account for annotations on the same line, definition closure should
            # be the last : on the line
            colon_loc = lines[def_end_lineno - 1].rfind(":")
            if colon_loc == -1:
                def_end_lineno -= 1
            else:
                # Lineno is 1-indexed, the line string is 0-indexed
                def_end_col_offset = colon_loc + 1
                break

        return_arg = Argument("return", def_end_lineno, def_end_col_offset, "RETURN")
        if node.returns:
            return_arg.has_type_annotation = True
            return_arg.has_3107_annotation = True
            new_function.is_return_annotated = True

        new_function.args.append(return_arg)

        return new_function


class FunctionVisitor(ast.NodeVisitor):

    def __init__(self, lines: List[str]):
        self.lines = lines
        self.function_definitions = []
        self._context = []

    def switch_context(self, node):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            if self._context and isinstance(self._context[-1], ast.ClassDef):
                self.function_definitions.append(
                    Function.from_function_node(node, self.lines, is_class_method=True)
                )
            else:
                self.function_definitions.append(Function.from_function_node(node, self.lines))
                
        self._context.append(node)
        self.generic_visit(node)
        self._context.pop()

    visit_FunctionDef = switch_context
    visit_AsyncFunctionDef = switch_context
    visit_ClassDef = switch_context


src = dedent(
    """\
    class Foo:
        class Bar:
            def bar_method(self):
                pass
        def foo_method(self):
            pass
    """
)

lines = src.splitlines()
tree = ast.parse(src)
visitor = FunctionVisitor(lines)
visitor.visit(tree)
defs = "\n".join(repr(fun) for fun in visitor.function_definitions)
print(f"\nFunction definitions:\n{defs}")

Now correctly prints:

Function definitions:
Function(name='bar_method', lineno=3, col_offset=8, function_type=PUBLIC, is_class_method=True, class_decorator_type=None, is_return_annotated=False)
Function(name='foo_method', lineno=5, col_offset=4, function_type=PUBLIC, is_class_method=True, class_decorator_type=None, is_return_annotated=False)
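
One property of checking only self._context[-1] is that the immediate parent scope alone decides the flag, so a helper function nested inside a method is not treated as a class method. A trimmed-down sketch to illustrate; MethodFlagVisitor and its found list are hypothetical stand-ins for FunctionVisitor and the Function objects.

```python
import ast
from textwrap import dedent


class MethodFlagVisitor(ast.NodeVisitor):
    """Hypothetical trimmed FunctionVisitor recording (name, is_class_method)."""

    def __init__(self) -> None:
        self.found = []
        self._context = []

    def switch_context(self, node) -> None:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Only the immediate parent scope decides whether this is a method
            in_class = bool(self._context) and isinstance(self._context[-1], ast.ClassDef)
            self.found.append((node.name, in_class))

        self._context.append(node)
        self.generic_visit(node)
        self._context.pop()

    visit_FunctionDef = switch_context
    visit_AsyncFunctionDef = switch_context
    visit_ClassDef = switch_context


src = dedent(
    """\
    class Foo:
        def foo_method(self):
            def helper():
                pass
            class Inner:
                async def inner_method(self):
                    pass
    def top_level():
        pass
    """
)

visitor = MethodFlagVisitor()
visitor.visit(ast.parse(src))
print(visitor.found)
```

Here foo_method and inner_method come out flagged as class methods, while helper and top_level do not.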

Thanks again everyone :)
