Hi all! I’m one of the maintainers for flake8-annotations, a plugin for detecting the absence of PEP 3107-style function annotations and PEP 484-style type comments.
I thought I had handled nested definitions well, at least according to the test cases I had initially come up with, but a few bugs have been exposed in the flow of parsing the AST that I’m having trouble finding a good way to deal with. The current approach for nesting is to call generic_visit() on the node to visit any nested functions. However, from the bugs identified I believe this approach is too generic & I need a method for restricting the depth of this generic visit to one level rather than the full depth.
For reference, the full parser can be found here (latest dev branch), but I’ll trim down into self-contained examples below. To note for later, beyond the ast.NodeVisitor subclasses there are also Function and Argument classes that are used to represent function definitions and arguments & contain the information needed for classification of potential linting errors.
Recently, we introduced an opinionated flag to suppress missing return annotations if the function implicitly or explicitly returns None. This was accomplished by adding an ast.NodeVisitor subclass that contains only a visit_Return method. This subclass is invoked as a checker when a Function instance is being created from a FunctionDef or AsyncFunctionDef node.
Here’s the visitor:
ReturnVisitor class
class ReturnVisitor(ast.NodeVisitor):
    def __init__(self):
        self.has_only_none_returns = True

    def visit_Return(self, node: ast.Return) -> None:
        if node.value is not None:
            if isinstance(node.value, (ast.Constant, ast.NameConstant)):
                if node.value.value is None:
                    return

            self.has_only_none_returns = False
And some sample code:
def foo():
    def bar() -> bool:
        return True

    bar()
With the opinionated flag set, the linter should not yield any linting errors. However, what actually happens is that when constructing the Function instance for foo, bar's return statement is visited and treated as a return for foo:
Full Example
import ast
from textwrap import dedent


class ReturnVisitor(ast.NodeVisitor):
    def __init__(self):
        self.has_only_none_returns = True

    def visit_Return(self, node: ast.Return) -> None:
        print(f"Node dump: {ast.dump(node)}")
        if node.value is not None:
            if isinstance(node.value, (ast.Constant, ast.NameConstant)):
                if node.value.value is None:
                    return

            self.has_only_none_returns = False


src = dedent(
    """\
    def foo():
        def bar() -> bool:
            return True

        bar()
    """
)

tree = ast.parse(src)
foo_visitor = ReturnVisitor()
foo_visitor.visit(tree)
print(f"foo has only none returns? {foo_visitor.has_only_none_returns}")
Which prints out:
Node dump: Return(value=Constant(value=True, kind=None))
foo has only none returns? False
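For illustration, here is a minimal sketch of one way the return check could be kept inside a single function scope. It is not what the plugin currently does: ScopedReturnVisitor and its check() method are hypothetical names, and the sketch assumes Python 3.8+, where None literals are parsed as ast.Constant. The idea is to visit the statements of the function body directly and to override visit_FunctionDef so that the visitor never descends into a nested definition.

```python
import ast
from textwrap import dedent


class ScopedReturnVisitor(ast.NodeVisitor):
    """Hypothetical variant of ReturnVisitor that ignores nested functions."""

    def __init__(self) -> None:
        self.has_only_none_returns = True

    def check(self, function_node: ast.AST) -> bool:
        # Visit the body statements rather than the function node itself, so
        # the visit_FunctionDef override below only fires for nested defs.
        for statement in function_node.body:
            self.visit(statement)
        return self.has_only_none_returns

    def visit_Return(self, node: ast.Return) -> None:
        if node.value is None:
            return  # Bare "return"
        if isinstance(node.value, ast.Constant) and node.value.value is None:
            return  # Explicit "return None"
        self.has_only_none_returns = False

    def visit_FunctionDef(self, node: ast.AST) -> None:
        # Deliberately no generic_visit(): returns inside a nested function
        # belong to that function, not to the one being checked.
        return

    visit_AsyncFunctionDef = visit_FunctionDef


src = dedent(
    """\
    def foo():
        def bar() -> bool:
            return True

        bar()
    """
)

foo_node = ast.parse(src).body[0]
bar_node = foo_node.body[0]

foo_result = ScopedReturnVisitor().check(foo_node)
bar_result = ScopedReturnVisitor().check(bar_node)
print(f"foo has only none returns? {foo_result}")
print(f"bar has only none returns? {bar_result}")
```

Under these assumptions, bar's return statement is only seen when bar itself is checked, so foo is reported as having only None returns.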
I’m running into what I believe is the same issue when parsing class methods, which are flagged explicitly as there are error codes specific to class methods. The function node is passed this flag by the visit_ClassDef method of the ast.NodeVisitor subclass being used for the general function parsing:
FunctionVisitor class
class FunctionVisitor(ast.NodeVisitor):
    def __init__(self, lines: List[str]):
        self.lines = lines
        self.function_definitions = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.function_definitions.append(Function.from_function_node(node, self.lines))
        self.generic_visit(node)  # Walk through any nested functions

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self.function_definitions.append(Function.from_function_node(node, self.lines))
        self.generic_visit(node)  # Walk through any nested functions

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        method_nodes = [
            child_node
            for child_node in node.body
            if isinstance(child_node, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        self.function_definitions.extend(
            [
                Function.from_function_node(method_node, self.lines, is_class_method=True)
                for method_node in method_nodes
            ]
        )

        # Use ast.NodeVisitor.generic_visit to start down the nested method chain
        for sub_node in node.body:
            self.generic_visit(sub_node)
While this handles nested functions of its own class methods, it fails to properly flag methods of nested classes. For example:
class Foo:
    class Bar:
        def bar_method(self):
            pass

    def foo_method(self):
        pass
Here, bar_method will not be properly classified as a class method of Bar: it is picked up by the more generic visit_FunctionDef rather than by the iteration over child nodes in visit_ClassDef, which is what passes the class method flag.
(Sorry, this one is quite long, I’ve trimmed out what I can)
Full Example
import ast
from textwrap import dedent
from typing import Any, List, Union
AST_FUNCTION_TYPES = Union[ast.FunctionDef, ast.AsyncFunctionDef]
AST_ARG_TYPES = ("args",) # Simplified for example


class Argument:
    def __init__(
        self,
        argname: str,
        lineno: int,
        col_offset: int,
        annotation_type: str,  # Snipped for example
        has_type_annotation: bool = False,
        has_3107_annotation: bool = False,
    ):
        self.argname = argname
        self.lineno = lineno
        self.col_offset = col_offset
        self.annotation_type = annotation_type
        self.has_type_annotation = has_type_annotation
        self.has_3107_annotation = has_3107_annotation

    def __repr__(self) -> str:
        """Format the Argument object into its "official" representation."""
        return (
            f"Argument(argname={self.argname!r}, lineno={self.lineno}, col_offset={self.col_offset}, "
            f"annotation_type={self.annotation_type}, has_type_annotation={self.has_type_annotation}, "
            f"has_3107_annotation={self.has_3107_annotation})"
        )

    @classmethod
    def from_arg_node(cls, node: ast.arguments, annotation_type_name: str):
        annotation_type = "Snipped"
        new_arg = cls(node.arg, node.lineno, node.col_offset, annotation_type)

        new_arg.has_type_annotation = False
        if node.annotation:
            new_arg.has_type_annotation = True
            new_arg.has_3107_annotation = True

        return new_arg


class Function:
    def __init__(
        self,
        name: str,
        lineno: int,
        col_offset: int,
        function_type: str = "PUBLIC",  # Simplified for example
        is_class_method: bool = False,
        class_decorator_type: Any = None,  # Simplified for example
        is_return_annotated: bool = False,
        args: List[Argument] = None,
    ):
        self.name = name
        self.lineno = lineno
        self.col_offset = col_offset
        self.function_type = function_type
        self.is_class_method = is_class_method
        self.class_decorator_type = class_decorator_type
        self.is_return_annotated = is_return_annotated
        self.args = args

    def is_fully_annotated(self) -> bool:
        return all(arg.has_type_annotation for arg in self.args)

    def get_missed_annotations(self) -> List:
        return [arg for arg in self.args if not arg.has_type_annotation]

    def get_annotated_arguments(self) -> List:
        return [arg for arg in self.args if arg.has_type_annotation]

    def __repr__(self) -> str:
        return (
            f"Function(name={self.name!r}, lineno={self.lineno}, col_offset={self.col_offset}, "
            f"function_type={self.function_type}, is_class_method={self.is_class_method}, "
            f"class_decorator_type={self.class_decorator_type}, "
            f"is_return_annotated={self.is_return_annotated})"
        )

    @classmethod
    def from_function_node(cls, node: AST_FUNCTION_TYPES, lines: List[str], **kwargs):
        # Extract function types from function name
        kwargs["function_type"] = "PUBLIC"  # Simplified for example

        if kwargs.get("is_class_method", False):
            kwargs["class_decorator_type"] = None  # Simplified for example

        new_function = cls(node.name, node.lineno, node.col_offset, **kwargs)

        # Iterate over arguments by type & add
        new_function.args = []
        for arg_type in AST_ARG_TYPES:
            args = node.args.__getattribute__(arg_type)
            if args:
                if not isinstance(args, list):
                    args = [args]

                new_function.args.extend(
                    [Argument.from_arg_node(arg, arg_type.upper()) for arg in args]
                )

        # Create an Argument object for the return hint
        # Get the line number from the line before where the body of the function starts to account
        # for the presence of decorators
        def_end_lineno = node.body[0].lineno - 1
        while True:
            # To account for multiline docstrings, rewind through the lines until we find the line
            # containing the :
            # Use str.rfind() to account for annotations on the same line, definition closure should
            # be the last : on the line
            colon_loc = lines[def_end_lineno - 1].rfind(":")
            if colon_loc == -1:
                def_end_lineno -= 1
            else:
                # Lineno is 1-indexed, the line string is 0-indexed
                def_end_col_offset = colon_loc + 1
                break

        return_arg = Argument("return", def_end_lineno, def_end_col_offset, "RETURN")
        if node.returns:
            return_arg.has_type_annotation = True
            return_arg.has_3107_annotation = True
            new_function.is_return_annotated = True

        new_function.args.append(return_arg)
        return new_function


class FunctionVisitor(ast.NodeVisitor):
    def __init__(self, lines: List[str]):
        self.lines = lines
        self.function_definitions = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        print("visit_FunctionDef visited")
        self.function_definitions.append(Function.from_function_node(node, self.lines))
        self.generic_visit(node)  # Walk through any nested functions

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        print("visit_ClassDef visited")
        method_nodes = [
            child_node
            for child_node in node.body
            if isinstance(child_node, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        self.function_definitions.extend(
            [
                Function.from_function_node(method_node, self.lines, is_class_method=True)
                for method_node in method_nodes
            ]
        )

        # Use ast.NodeVisitor.generic_visit to start down the nested method chain
        for sub_node in node.body:
            self.generic_visit(sub_node)


src = dedent(
    """\
    class Foo:
        class Bar:
            def bar_method(self):
                pass
        def foo_method(self):
            pass
    """
)

lines = src.splitlines()
tree = ast.parse(src)
visitor = FunctionVisitor(lines)
visitor.visit(tree)

defs = "\n".join(repr(fun) for fun in visitor.function_definitions)
print(f"\nFunction definitions:\n{defs}")
Which prints out:
visit_ClassDef visited
visit_FunctionDef visited
Function definitions:
Function(name='foo_method', lineno=5, col_offset=4, function_type=PUBLIC, is_class_method=True, class_decorator_type=None, is_return_annotated=False)
Function(name='bar_method', lineno=3, col_offset=8, function_type=PUBLIC, is_class_method=False, class_decorator_type=None, is_return_annotated=False)
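As a rough sketch of an alternative to special-casing methods inside visit_ClassDef, the visitor could keep a stack of the definitions it is currently inside and derive the class method flag from the innermost one. This is only an illustration under my own assumptions: ContextAwareVisitor, found, and _context are hypothetical names, and plain (name, is_class_method) tuples stand in for the Function class to keep the example short.

```python
import ast
from textwrap import dedent
from typing import List, Tuple


class ContextAwareVisitor(ast.NodeVisitor):
    """Hypothetical visitor that tracks the enclosing class/function nodes."""

    def __init__(self) -> None:
        self.found: List[Tuple[str, bool]] = []  # (function name, is_class_method)
        self._context: List[ast.AST] = []  # Innermost enclosing definition is last

    def _visit_function(self, node: ast.AST) -> None:
        # A function is a method only if its innermost enclosing definition is
        # a class; a function nested inside a method is not a method itself.
        is_class_method = bool(self._context) and isinstance(self._context[-1], ast.ClassDef)
        self.found.append((node.name, is_class_method))

        self._context.append(node)
        self.generic_visit(node)  # Full-depth walk is fine, the stack keeps track of scope
        self._context.pop()

    visit_FunctionDef = _visit_function
    visit_AsyncFunctionDef = _visit_function

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self._context.append(node)
        self.generic_visit(node)  # Nested classes come back through visit_ClassDef
        self._context.pop()


src = dedent(
    """\
    class Foo:
        class Bar:
            def bar_method(self):
                pass

        def foo_method(self):
            def helper():
                pass
    """
)

visitor = ContextAwareVisitor()
visitor.visit(ast.parse(src))
for name, is_class_method in visitor.found:
    print(f"{name}: is_class_method={is_class_method}")
```

With this approach every class, including a nested one like Bar, goes through visit_ClassDef, so bar_method and foo_method are both flagged while the nested helper is not. The trade-off is that the visitor carries a little state instead of relying on a depth-limited generic visit.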
If you’ve made it this far, thank you!
To reiterate from above: The current approach being taken for nesting is to use a generic_visit() on the node to visit any nested functions. However, from the bugs identified I believe this approach is too generic & I need to derive a method for restricting the depth of this generic visit to one level rather than to the full depth. I’m struggling with trying to grok a method for doing this that doesn’t turn into a total monstrosity, and would love any pointers!
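To make the "one level versus full depth" distinction concrete, here is a small sketch comparing ast.iter_child_nodes(), which yields only the direct children of a node, with ast.walk(), which descends through everything. The helper name direct_function_children is hypothetical and only for illustration.

```python
import ast
from textwrap import dedent
from typing import List

FUNCTION_NODES = (ast.FunctionDef, ast.AsyncFunctionDef)


def direct_function_children(node: ast.AST) -> List[ast.AST]:
    """Return only the function definitions that are direct children of node."""
    return [child for child in ast.iter_child_nodes(node) if isinstance(child, FUNCTION_NODES)]


src = dedent(
    """\
    def foo():
        def bar():
            def baz():
                pass
    """
)

foo_node = ast.parse(src).body[0]

# One level: only definitions sitting directly in foo's body
one_level = [child.name for child in direct_function_children(foo_node)]

# Full depth: ast.walk() yields foo itself and every descendant
full_depth = [node.name for node in ast.walk(foo_node) if isinstance(node, FUNCTION_NODES)]

print(f"One level: {one_level}")
print(f"Full depth: {full_depth}")
```

One caveat with a strict one-level restriction: a definition wrapped in another statement, such as an if or try block in the function body, is not a direct child and would be missed, so this may be too shallow on its own.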
If I’ve missed anything in my explanation or the examples, please let me know!