updates
This commit is contained in:
@@ -1,13 +1,5 @@
|
||||
"""Rewrite assertion AST to produce nice error messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
import errno
|
||||
import functools
|
||||
import importlib.abc
|
||||
@@ -17,46 +9,53 @@ import io
|
||||
import itertools
|
||||
import marshal
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pathlib import PurePath
|
||||
import struct
|
||||
import sys
|
||||
import tokenize
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from pathlib import PurePath
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import IO
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from importlib.resources.abc import TraversableResources
|
||||
else:
|
||||
from importlib.abc import TraversableResources
|
||||
if sys.version_info < (3, 11):
|
||||
from importlib.readers import FileReader
|
||||
else:
|
||||
from importlib.resources.readers import FileReader
|
||||
|
||||
from typing import Union
|
||||
|
||||
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
|
||||
from _pytest._io.saferepr import saferepr
|
||||
from _pytest._io.saferepr import saferepr_unlimited
|
||||
from _pytest._version import version
|
||||
from _pytest.assertion import util
|
||||
from _pytest.assertion.util import ( # noqa: F401
|
||||
format_explanation as _format_explanation,
|
||||
)
|
||||
from _pytest.config import Config
|
||||
from _pytest.fixtures import FixtureFunctionDefinition
|
||||
from _pytest.main import Session
|
||||
from _pytest.pathlib import absolutepath
|
||||
from _pytest.pathlib import fnmatch_ex
|
||||
from _pytest.stash import StashKey
|
||||
|
||||
|
||||
# fmt: off
|
||||
from _pytest.assertion.util import format_explanation as _format_explanation # noqa:F401, isort:skip
|
||||
# fmt:on
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.assertion import AssertionState
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
namedExpr = ast.NamedExpr
|
||||
astNameConstant = ast.Constant
|
||||
astStr = ast.Constant
|
||||
astNum = ast.Constant
|
||||
else:
|
||||
namedExpr = ast.Expr
|
||||
astNameConstant = ast.NameConstant
|
||||
astStr = ast.Str
|
||||
astNum = ast.Num
|
||||
|
||||
|
||||
class Sentinel:
|
||||
pass
|
||||
@@ -66,7 +65,7 @@ assertstate_key = StashKey["AssertionState"]()
|
||||
|
||||
# pytest caches rewritten pycs in pycache dirs
|
||||
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
|
||||
PYC_EXT = ".py" + ((__debug__ and "c") or "o")
|
||||
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
||||
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
|
||||
|
||||
# Special marker that denotes we have just left a scope definition
|
||||
@@ -82,17 +81,17 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
self.fnpats = config.getini("python_files")
|
||||
except ValueError:
|
||||
self.fnpats = ["test_*.py", "*_test.py"]
|
||||
self.session: Session | None = None
|
||||
self._rewritten_names: dict[str, Path] = {}
|
||||
self._must_rewrite: set[str] = set()
|
||||
self.session: Optional[Session] = None
|
||||
self._rewritten_names: Dict[str, Path] = {}
|
||||
self._must_rewrite: Set[str] = set()
|
||||
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
|
||||
# which might result in infinite recursion (#3506)
|
||||
self._writing_pyc = False
|
||||
self._basenames_to_check_rewrite = {"conftest"}
|
||||
self._marked_for_rewrite_cache: dict[str, bool] = {}
|
||||
self._marked_for_rewrite_cache: Dict[str, bool] = {}
|
||||
self._session_paths_checked = False
|
||||
|
||||
def set_session(self, session: Session | None) -> None:
|
||||
def set_session(self, session: Optional[Session]) -> None:
|
||||
self.session = session
|
||||
self._session_paths_checked = False
|
||||
|
||||
@@ -102,28 +101,18 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
def find_spec(
|
||||
self,
|
||||
name: str,
|
||||
path: Sequence[str | bytes] | None = None,
|
||||
target: types.ModuleType | None = None,
|
||||
) -> importlib.machinery.ModuleSpec | None:
|
||||
path: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
target: Optional[types.ModuleType] = None,
|
||||
) -> Optional[importlib.machinery.ModuleSpec]:
|
||||
if self._writing_pyc:
|
||||
return None
|
||||
state = self.config.stash[assertstate_key]
|
||||
if self._early_rewrite_bailout(name, state):
|
||||
return None
|
||||
state.trace(f"find_module called for: {name}")
|
||||
state.trace("find_module called for: %s" % name)
|
||||
|
||||
# Type ignored because mypy is confused about the `self` binding here.
|
||||
spec = self._find_spec(name, path) # type: ignore
|
||||
|
||||
if spec is None and path is not None:
|
||||
# With --import-mode=importlib, PathFinder cannot find spec without modifying `sys.path`,
|
||||
# causing inability to assert rewriting (#12659).
|
||||
# At this point, try using the file path to find the module spec.
|
||||
for _path_str in path:
|
||||
spec = importlib.util.spec_from_file_location(name, _path_str)
|
||||
if spec is not None:
|
||||
break
|
||||
|
||||
if (
|
||||
# the import machinery could not find a file to import
|
||||
spec is None
|
||||
@@ -151,7 +140,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
|
||||
def create_module(
|
||||
self, spec: importlib.machinery.ModuleSpec
|
||||
) -> types.ModuleType | None:
|
||||
) -> Optional[types.ModuleType]:
|
||||
return None # default behaviour is fine
|
||||
|
||||
def exec_module(self, module: types.ModuleType) -> None:
|
||||
@@ -196,7 +185,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
state.trace(f"found cached rewritten pyc for {fn}")
|
||||
exec(co, module.__dict__)
|
||||
|
||||
def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool:
|
||||
def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
|
||||
"""A fast way to get out of rewriting modules.
|
||||
|
||||
Profiling has shown that the call to PathFinder.find_spec (inside of
|
||||
@@ -235,7 +224,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
state.trace(f"early skip of rewriting module: {name}")
|
||||
return True
|
||||
|
||||
def _should_rewrite(self, name: str, fn: str, state: AssertionState) -> bool:
|
||||
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
|
||||
# always rewrite conftest files
|
||||
if os.path.basename(fn) == "conftest.py":
|
||||
state.trace(f"rewriting conftest file: {fn!r}")
|
||||
@@ -256,7 +245,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
|
||||
return self._is_marked_for_rewrite(name, state)
|
||||
|
||||
def _is_marked_for_rewrite(self, name: str, state: AssertionState) -> bool:
|
||||
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
|
||||
try:
|
||||
return self._marked_for_rewrite_cache[name]
|
||||
except KeyError:
|
||||
@@ -292,18 +281,31 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
|
||||
self.config.issue_config_time_warning(
|
||||
PytestAssertRewriteWarning(
|
||||
f"Module already imported so cannot be rewritten; {name}"
|
||||
"Module already imported so cannot be rewritten: %s" % name
|
||||
),
|
||||
stacklevel=5,
|
||||
)
|
||||
|
||||
def get_data(self, pathname: str | bytes) -> bytes:
|
||||
def get_data(self, pathname: Union[str, bytes]) -> bytes:
|
||||
"""Optional PEP302 get_data API."""
|
||||
with open(pathname, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def get_resource_reader(self, name: str) -> TraversableResources:
|
||||
return FileReader(types.SimpleNamespace(path=self._rewritten_names[name])) # type: ignore[arg-type]
|
||||
if sys.version_info >= (3, 10):
|
||||
if sys.version_info >= (3, 12):
|
||||
from importlib.resources.abc import TraversableResources
|
||||
else:
|
||||
from importlib.abc import TraversableResources
|
||||
|
||||
def get_resource_reader(self, name: str) -> TraversableResources: # type: ignore
|
||||
if sys.version_info < (3, 11):
|
||||
from importlib.readers import FileReader
|
||||
else:
|
||||
from importlib.resources.readers import FileReader
|
||||
|
||||
return FileReader( # type:ignore[no-any-return]
|
||||
types.SimpleNamespace(path=self._rewritten_names[name])
|
||||
)
|
||||
|
||||
|
||||
def _write_pyc_fp(
|
||||
@@ -325,7 +327,7 @@ def _write_pyc_fp(
|
||||
|
||||
|
||||
def _write_pyc(
|
||||
state: AssertionState,
|
||||
state: "AssertionState",
|
||||
co: types.CodeType,
|
||||
source_stat: os.stat_result,
|
||||
pyc: Path,
|
||||
@@ -349,7 +351,7 @@ def _write_pyc(
|
||||
return True
|
||||
|
||||
|
||||
def _rewrite_test(fn: Path, config: Config) -> tuple[os.stat_result, types.CodeType]:
|
||||
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
|
||||
"""Read and rewrite *fn* and return the code object."""
|
||||
stat = os.stat(fn)
|
||||
source = fn.read_bytes()
|
||||
@@ -362,7 +364,7 @@ def _rewrite_test(fn: Path, config: Config) -> tuple[os.stat_result, types.CodeT
|
||||
|
||||
def _read_pyc(
|
||||
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
|
||||
) -> types.CodeType | None:
|
||||
) -> Optional[types.CodeType]:
|
||||
"""Possibly read a pytest pyc containing rewritten code.
|
||||
|
||||
Return rewritten code if successful or None if not.
|
||||
@@ -382,21 +384,21 @@ def _read_pyc(
|
||||
return None
|
||||
# Check for invalid or out of date pyc file.
|
||||
if len(data) != (16):
|
||||
trace(f"_read_pyc({source}): invalid pyc (too short)")
|
||||
trace("_read_pyc(%s): invalid pyc (too short)" % source)
|
||||
return None
|
||||
if data[:4] != importlib.util.MAGIC_NUMBER:
|
||||
trace(f"_read_pyc({source}): invalid pyc (bad magic number)")
|
||||
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
|
||||
return None
|
||||
if data[4:8] != b"\x00\x00\x00\x00":
|
||||
trace(f"_read_pyc({source}): invalid pyc (unsupported flags)")
|
||||
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
|
||||
return None
|
||||
mtime_data = data[8:12]
|
||||
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
|
||||
trace(f"_read_pyc({source}): out of date")
|
||||
trace("_read_pyc(%s): out of date" % source)
|
||||
return None
|
||||
size_data = data[12:16]
|
||||
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
|
||||
trace(f"_read_pyc({source}): invalid pyc (incorrect size)")
|
||||
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
|
||||
return None
|
||||
try:
|
||||
co = marshal.load(fp)
|
||||
@@ -404,7 +406,7 @@ def _read_pyc(
|
||||
trace(f"_read_pyc({source}): marshal.load error {e}")
|
||||
return None
|
||||
if not isinstance(co, types.CodeType):
|
||||
trace(f"_read_pyc({source}): not a code object")
|
||||
trace("_read_pyc(%s): not a code object" % source)
|
||||
return None
|
||||
return co
|
||||
|
||||
@@ -412,8 +414,8 @@ def _read_pyc(
|
||||
def rewrite_asserts(
|
||||
mod: ast.Module,
|
||||
source: bytes,
|
||||
module_path: str | None = None,
|
||||
config: Config | None = None,
|
||||
module_path: Optional[str] = None,
|
||||
config: Optional[Config] = None,
|
||||
) -> None:
|
||||
"""Rewrite the assert statements in mod."""
|
||||
AssertionRewriter(module_path, config, source).run(mod)
|
||||
@@ -429,22 +431,13 @@ def _saferepr(obj: object) -> str:
|
||||
sequences, especially '\n{' and '\n}' are likely to be present in
|
||||
JSON reprs.
|
||||
"""
|
||||
if isinstance(obj, types.MethodType):
|
||||
# for bound methods, skip redundant <bound method ...> information
|
||||
return obj.__name__
|
||||
|
||||
maxsize = _get_maxsize_for_saferepr(util._config)
|
||||
if not maxsize:
|
||||
return saferepr_unlimited(obj).replace("\n", "\\n")
|
||||
return saferepr(obj, maxsize=maxsize).replace("\n", "\\n")
|
||||
|
||||
|
||||
def _get_maxsize_for_saferepr(config: Config | None) -> int | None:
|
||||
def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
|
||||
"""Get `maxsize` configuration for saferepr based on the given config object."""
|
||||
if config is None:
|
||||
verbosity = 0
|
||||
else:
|
||||
verbosity = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
|
||||
verbosity = config.getoption("verbose") if config is not None else 0
|
||||
if verbosity >= 2:
|
||||
return None
|
||||
if verbosity >= 1:
|
||||
@@ -465,7 +458,7 @@ def _format_assertmsg(obj: object) -> str:
|
||||
# However in either case we want to preserve the newline.
|
||||
replaces = [("\n", "\n~"), ("%", "%%")]
|
||||
if not isinstance(obj, str):
|
||||
obj = saferepr(obj, _get_maxsize_for_saferepr(util._config))
|
||||
obj = saferepr(obj)
|
||||
replaces.append(("\\n", "\n~"))
|
||||
|
||||
for r1, r2 in replaces:
|
||||
@@ -476,8 +469,7 @@ def _format_assertmsg(obj: object) -> str:
|
||||
|
||||
def _should_repr_global_name(obj: object) -> bool:
|
||||
if callable(obj):
|
||||
# For pytest fixtures the __repr__ method provides more information than the function name.
|
||||
return isinstance(obj, FixtureFunctionDefinition)
|
||||
return False
|
||||
|
||||
try:
|
||||
return not hasattr(obj, "__name__")
|
||||
@@ -486,7 +478,7 @@ def _should_repr_global_name(obj: object) -> bool:
|
||||
|
||||
|
||||
def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
|
||||
explanation = "(" + ((is_or and " or ") or " and ").join(explanations) + ")"
|
||||
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
||||
return explanation.replace("%", "%%")
|
||||
|
||||
|
||||
@@ -496,7 +488,7 @@ def _call_reprcompare(
|
||||
expls: Sequence[str],
|
||||
each_obj: Sequence[object],
|
||||
) -> str:
|
||||
for i, res, expl in zip(range(len(ops)), results, expls, strict=True):
|
||||
for i, res, expl in zip(range(len(ops)), results, expls):
|
||||
try:
|
||||
done = not res
|
||||
except Exception:
|
||||
@@ -558,14 +550,14 @@ def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _get_assertion_exprs(src: bytes) -> dict[int, str]:
|
||||
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
|
||||
"""Return a mapping from {lineno: "assertion test expression"}."""
|
||||
ret: dict[int, str] = {}
|
||||
ret: Dict[int, str] = {}
|
||||
|
||||
depth = 0
|
||||
lines: list[str] = []
|
||||
assert_lineno: int | None = None
|
||||
seen_lines: set[int] = set()
|
||||
lines: List[str] = []
|
||||
assert_lineno: Optional[int] = None
|
||||
seen_lines: Set[int] = set()
|
||||
|
||||
def _write_and_reset() -> None:
|
||||
nonlocal depth, lines, assert_lineno, seen_lines
|
||||
@@ -599,7 +591,7 @@ def _get_assertion_exprs(src: bytes) -> dict[int, str]:
|
||||
# multi-line assert with message
|
||||
elif lineno in seen_lines:
|
||||
lines[-1] = lines[-1][:offset]
|
||||
# multi line assert with escaped newline before message
|
||||
# multi line assert with escapd newline before message
|
||||
else:
|
||||
lines.append(line[:offset])
|
||||
_write_and_reset()
|
||||
@@ -672,7 +664,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, module_path: str | None, config: Config | None, source: bytes
|
||||
self, module_path: Optional[str], config: Optional[Config], source: bytes
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.module_path = module_path
|
||||
@@ -685,9 +677,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
self.enable_assertion_pass_hook = False
|
||||
self.source = source
|
||||
self.scope: tuple[ast.AST, ...] = ()
|
||||
self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = (
|
||||
defaultdict(dict)
|
||||
)
|
||||
self.variables_overwrite: defaultdict[
|
||||
tuple[ast.AST, ...], Dict[str, str]
|
||||
] = defaultdict(dict)
|
||||
|
||||
def run(self, mod: ast.Module) -> None:
|
||||
"""Find all assert statements in *mod* and rewrite them."""
|
||||
@@ -702,18 +694,28 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
if doc is not None and self.is_rewrite_disabled(doc):
|
||||
return
|
||||
pos = 0
|
||||
item = None
|
||||
for item in mod.body:
|
||||
match item:
|
||||
case ast.Expr(value=ast.Constant(value=str() as doc)) if (
|
||||
expect_docstring
|
||||
):
|
||||
if self.is_rewrite_disabled(doc):
|
||||
return
|
||||
expect_docstring = False
|
||||
case ast.ImportFrom(level=0, module="__future__"):
|
||||
pass
|
||||
case _:
|
||||
break
|
||||
if (
|
||||
expect_docstring
|
||||
and isinstance(item, ast.Expr)
|
||||
and isinstance(item.value, astStr)
|
||||
):
|
||||
if sys.version_info >= (3, 8):
|
||||
doc = item.value.value
|
||||
else:
|
||||
doc = item.value.s
|
||||
if self.is_rewrite_disabled(doc):
|
||||
return
|
||||
expect_docstring = False
|
||||
elif (
|
||||
isinstance(item, ast.ImportFrom)
|
||||
and item.level == 0
|
||||
and item.module == "__future__"
|
||||
):
|
||||
pass
|
||||
else:
|
||||
break
|
||||
pos += 1
|
||||
# Special case: for a decorated function, set the lineno to that of the
|
||||
# first decorator, not the `def`. Issue #4984.
|
||||
@@ -722,15 +724,21 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
else:
|
||||
lineno = item.lineno
|
||||
# Now actually insert the special imports.
|
||||
aliases = [
|
||||
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
|
||||
ast.alias(
|
||||
"_pytest.assertion.rewrite",
|
||||
"@pytest_ar",
|
||||
lineno=lineno,
|
||||
col_offset=0,
|
||||
),
|
||||
]
|
||||
if sys.version_info >= (3, 10):
|
||||
aliases = [
|
||||
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
|
||||
ast.alias(
|
||||
"_pytest.assertion.rewrite",
|
||||
"@pytest_ar",
|
||||
lineno=lineno,
|
||||
col_offset=0,
|
||||
),
|
||||
]
|
||||
else:
|
||||
aliases = [
|
||||
ast.alias("builtins", "@py_builtins"),
|
||||
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
|
||||
]
|
||||
imports = [
|
||||
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
|
||||
]
|
||||
@@ -738,10 +746,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
|
||||
# Collect asserts.
|
||||
self.scope = (mod,)
|
||||
nodes: list[ast.AST | Sentinel] = [mod]
|
||||
nodes: List[Union[ast.AST, Sentinel]] = [mod]
|
||||
while nodes:
|
||||
node = nodes.pop()
|
||||
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
self.scope = tuple((*self.scope, node))
|
||||
nodes.append(_SCOPE_END_MARKER)
|
||||
if node == _SCOPE_END_MARKER:
|
||||
@@ -750,7 +758,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
assert isinstance(node, ast.AST)
|
||||
for name, field in ast.iter_fields(node):
|
||||
if isinstance(field, list):
|
||||
new: list[ast.AST] = []
|
||||
new: List[ast.AST] = []
|
||||
for i, child in enumerate(field):
|
||||
if isinstance(child, ast.Assert):
|
||||
# Transform assert.
|
||||
@@ -783,7 +791,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
"""Give *expr* a name."""
|
||||
name = self.variable()
|
||||
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
||||
return ast.copy_location(ast.Name(name, ast.Load()), expr)
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
def display(self, expr: ast.expr) -> ast.expr:
|
||||
"""Call saferepr on the expression."""
|
||||
@@ -822,7 +830,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
to format a string of %-formatted values as added by
|
||||
.explanation_param().
|
||||
"""
|
||||
self.explanation_specifiers: dict[str, ast.expr] = {}
|
||||
self.explanation_specifiers: Dict[str, ast.expr] = {}
|
||||
self.stack.append(self.explanation_specifiers)
|
||||
|
||||
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
|
||||
@@ -836,7 +844,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
current = self.stack.pop()
|
||||
if self.stack:
|
||||
self.explanation_specifiers = self.stack[-1]
|
||||
keys: list[ast.expr | None] = [ast.Constant(key) for key in current.keys()]
|
||||
keys = [astStr(key) for key in current.keys()]
|
||||
format_dict = ast.Dict(keys, list(current.values()))
|
||||
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
||||
name = "@py_format" + str(next(self.variable_counter))
|
||||
@@ -845,13 +853,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
def generic_visit(self, node: ast.AST) -> tuple[ast.Name, str]:
|
||||
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
|
||||
"""Handle expressions we don't have custom code for."""
|
||||
assert isinstance(node, ast.expr)
|
||||
res = self.assign(node)
|
||||
return res, self.explanation_param(self.display(res))
|
||||
|
||||
def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
|
||||
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
|
||||
"""Return the AST statements to replace the ast.Assert instance.
|
||||
|
||||
This rewrites the test of an assertion to provide
|
||||
@@ -860,9 +868,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
the expression is false.
|
||||
"""
|
||||
if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
|
||||
import warnings
|
||||
|
||||
from _pytest.warning_types import PytestAssertRewriteWarning
|
||||
import warnings
|
||||
|
||||
# TODO: This assert should not be needed.
|
||||
assert self.module_path is not None
|
||||
@@ -875,15 +882,15 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
lineno=assert_.lineno,
|
||||
)
|
||||
|
||||
self.statements: list[ast.stmt] = []
|
||||
self.variables: list[str] = []
|
||||
self.statements: List[ast.stmt] = []
|
||||
self.variables: List[str] = []
|
||||
self.variable_counter = itertools.count()
|
||||
|
||||
if self.enable_assertion_pass_hook:
|
||||
self.format_variables: list[str] = []
|
||||
self.format_variables: List[str] = []
|
||||
|
||||
self.stack: list[dict[str, ast.expr]] = []
|
||||
self.expl_stmts: list[ast.stmt] = []
|
||||
self.stack: List[Dict[str, ast.expr]] = []
|
||||
self.expl_stmts: List[ast.stmt] = []
|
||||
self.push_format_context()
|
||||
# Rewrite assert into a bunch of statements.
|
||||
top_condition, explanation = self.visit(assert_.test)
|
||||
@@ -891,16 +898,16 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
negation = ast.UnaryOp(ast.Not(), top_condition)
|
||||
|
||||
if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
|
||||
msg = self.pop_format_context(ast.Constant(explanation))
|
||||
msg = self.pop_format_context(astStr(explanation))
|
||||
|
||||
# Failed
|
||||
if assert_.msg:
|
||||
assertmsg = self.helper("_format_assertmsg", assert_.msg)
|
||||
gluestr = "\n>assert "
|
||||
else:
|
||||
assertmsg = ast.Constant("")
|
||||
assertmsg = astStr("")
|
||||
gluestr = "assert "
|
||||
err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
|
||||
err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
|
||||
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
|
||||
err_name = ast.Name("AssertionError", ast.Load())
|
||||
fmt = self.helper("_format_explanation", err_msg)
|
||||
@@ -916,27 +923,27 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
hook_call_pass = ast.Expr(
|
||||
self.helper(
|
||||
"_call_assertion_pass",
|
||||
ast.Constant(assert_.lineno),
|
||||
ast.Constant(orig),
|
||||
astNum(assert_.lineno),
|
||||
astStr(orig),
|
||||
fmt_pass,
|
||||
)
|
||||
)
|
||||
# If any hooks implement assert_pass hook
|
||||
hook_impl_test = ast.If(
|
||||
self.helper("_check_if_assertion_pass_impl"),
|
||||
[*self.expl_stmts, hook_call_pass],
|
||||
self.expl_stmts + [hook_call_pass],
|
||||
[],
|
||||
)
|
||||
statements_pass: list[ast.stmt] = [hook_impl_test]
|
||||
statements_pass = [hook_impl_test]
|
||||
|
||||
# Test for assertion condition
|
||||
main_test = ast.If(negation, statements_fail, statements_pass)
|
||||
self.statements.append(main_test)
|
||||
if self.format_variables:
|
||||
variables: list[ast.expr] = [
|
||||
variables = [
|
||||
ast.Name(name, ast.Store()) for name in self.format_variables
|
||||
]
|
||||
clear_format = ast.Assign(variables, ast.Constant(None))
|
||||
clear_format = ast.Assign(variables, astNameConstant(None))
|
||||
self.statements.append(clear_format)
|
||||
|
||||
else: # Original assertion rewriting
|
||||
@@ -947,9 +954,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
assertmsg = self.helper("_format_assertmsg", assert_.msg)
|
||||
explanation = "\n>assert " + explanation
|
||||
else:
|
||||
assertmsg = ast.Constant("")
|
||||
assertmsg = astStr("")
|
||||
explanation = "assert " + explanation
|
||||
template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
|
||||
template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
|
||||
msg = self.pop_format_context(template)
|
||||
fmt = self.helper("_format_explanation", msg)
|
||||
err_name = ast.Name("AssertionError", ast.Load())
|
||||
@@ -961,40 +968,37 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
# Clear temporary variables by setting them to None.
|
||||
if self.variables:
|
||||
variables = [ast.Name(name, ast.Store()) for name in self.variables]
|
||||
clear = ast.Assign(variables, ast.Constant(None))
|
||||
clear = ast.Assign(variables, astNameConstant(None))
|
||||
self.statements.append(clear)
|
||||
# Fix locations (line numbers/column offsets).
|
||||
for stmt in self.statements:
|
||||
for node in traverse_node(stmt):
|
||||
if getattr(node, "lineno", None) is None:
|
||||
# apply the assertion location to all generated ast nodes without source location
|
||||
# and preserve the location of existing nodes or generated nodes with an correct location.
|
||||
ast.copy_location(node, assert_)
|
||||
ast.copy_location(node, assert_)
|
||||
return self.statements
|
||||
|
||||
def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
|
||||
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
|
||||
# This method handles the 'walrus operator' repr of the target
|
||||
# name if it's a local variable or _should_repr_global_name()
|
||||
# thinks it's acceptable.
|
||||
locs = ast.Call(self.builtin("locals"), [], [])
|
||||
target_id = name.target.id
|
||||
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
|
||||
target_id = name.target.id # type: ignore[attr-defined]
|
||||
inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs])
|
||||
dorepr = self.helper("_should_repr_global_name", name)
|
||||
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
|
||||
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
|
||||
expr = ast.IfExp(test, self.display(name), astStr(target_id))
|
||||
return name, self.explanation_param(expr)
|
||||
|
||||
def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
|
||||
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
|
||||
# Display the repr of the name if it's a local variable or
|
||||
# _should_repr_global_name() thinks it's acceptable.
|
||||
locs = ast.Call(self.builtin("locals"), [], [])
|
||||
inlocs = ast.Compare(ast.Constant(name.id), [ast.In()], [locs])
|
||||
inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs])
|
||||
dorepr = self.helper("_should_repr_global_name", name)
|
||||
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
|
||||
expr = ast.IfExp(test, self.display(name), ast.Constant(name.id))
|
||||
expr = ast.IfExp(test, self.display(name), astStr(name.id))
|
||||
return name, self.explanation_param(expr)
|
||||
|
||||
def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
|
||||
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
|
||||
res_var = self.variable()
|
||||
expl_list = self.assign(ast.List([], ast.Load()))
|
||||
app = ast.Attribute(expl_list, "append", ast.Load())
|
||||
@@ -1006,57 +1010,60 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
# Process each operand, short-circuiting if needed.
|
||||
for i, v in enumerate(boolop.values):
|
||||
if i:
|
||||
fail_inner: list[ast.stmt] = []
|
||||
fail_inner: List[ast.stmt] = []
|
||||
# cond is set in a prior loop iteration below
|
||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
|
||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
||||
self.expl_stmts = fail_inner
|
||||
match v:
|
||||
# Check if the left operand is an ast.NamedExpr and the value has already been visited
|
||||
case ast.Compare(
|
||||
left=ast.NamedExpr(target=ast.Name(id=target_id))
|
||||
) if target_id in [
|
||||
e.id for e in boolop.values[:i] if hasattr(e, "id")
|
||||
]:
|
||||
pytest_temp = self.variable()
|
||||
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
|
||||
# mypy's false positive, we're checking that the 'target' attribute exists.
|
||||
v.left.target.id = pytest_temp # type:ignore[attr-defined]
|
||||
# Check if the left operand is a namedExpr and the value has already been visited
|
||||
if (
|
||||
isinstance(v, ast.Compare)
|
||||
and isinstance(v.left, namedExpr)
|
||||
and v.left.target.id
|
||||
in [
|
||||
ast_expr.id
|
||||
for ast_expr in boolop.values[:i]
|
||||
if hasattr(ast_expr, "id")
|
||||
]
|
||||
):
|
||||
pytest_temp = self.variable()
|
||||
self.variables_overwrite[self.scope][
|
||||
v.left.target.id
|
||||
] = v.left # type:ignore[assignment]
|
||||
v.left.target.id = pytest_temp
|
||||
self.push_format_context()
|
||||
res, expl = self.visit(v)
|
||||
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
||||
expl_format = self.pop_format_context(ast.Constant(expl))
|
||||
expl_format = self.pop_format_context(astStr(expl))
|
||||
call = ast.Call(app, [expl_format], [])
|
||||
self.expl_stmts.append(ast.Expr(call))
|
||||
if i < levels:
|
||||
cond: ast.expr = res
|
||||
if is_or:
|
||||
cond = ast.UnaryOp(ast.Not(), cond)
|
||||
inner: list[ast.stmt] = []
|
||||
inner: List[ast.stmt] = []
|
||||
self.statements.append(ast.If(cond, inner, []))
|
||||
self.statements = body = inner
|
||||
self.statements = save
|
||||
self.expl_stmts = fail_save
|
||||
expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
|
||||
expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
|
||||
expl = self.pop_format_context(expl_template)
|
||||
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
||||
|
||||
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
|
||||
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
|
||||
pattern = UNARY_MAP[unary.op.__class__]
|
||||
operand_res, operand_expl = self.visit(unary.operand)
|
||||
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
|
||||
res = self.assign(ast.UnaryOp(unary.op, operand_res))
|
||||
return res, pattern % (operand_expl,)
|
||||
|
||||
def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
|
||||
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
|
||||
symbol = BINOP_MAP[binop.op.__class__]
|
||||
left_expr, left_expl = self.visit(binop.left)
|
||||
right_expr, right_expl = self.visit(binop.right)
|
||||
explanation = f"({left_expl} {symbol} {right_expl})"
|
||||
res = self.assign(
|
||||
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
|
||||
)
|
||||
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
|
||||
return res, explanation
|
||||
|
||||
def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
|
||||
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
||||
new_func, func_expl = self.visit(call.func)
|
||||
arg_expls = []
|
||||
new_args = []
|
||||
@@ -1065,16 +1072,19 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
|
||||
self.scope, {}
|
||||
):
|
||||
arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment]
|
||||
arg = self.variables_overwrite[self.scope][
|
||||
arg.id
|
||||
] # type:ignore[assignment]
|
||||
res, expl = self.visit(arg)
|
||||
arg_expls.append(expl)
|
||||
new_args.append(res)
|
||||
for keyword in call.keywords:
|
||||
match keyword.value:
|
||||
case ast.Name(id=id) if id in self.variables_overwrite.get(
|
||||
self.scope, {}
|
||||
):
|
||||
keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment]
|
||||
if isinstance(
|
||||
keyword.value, ast.Name
|
||||
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
|
||||
keyword.value = self.variables_overwrite[self.scope][
|
||||
keyword.value.id
|
||||
] # type:ignore[assignment]
|
||||
res, expl = self.visit(keyword.value)
|
||||
new_kwargs.append(ast.keyword(keyword.arg, res))
|
||||
if keyword.arg:
|
||||
@@ -1083,68 +1093,70 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
arg_expls.append("**" + expl)
|
||||
|
||||
expl = "{}({})".format(func_expl, ", ".join(arg_expls))
|
||||
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
|
||||
new_call = ast.Call(new_func, new_args, new_kwargs)
|
||||
res = self.assign(new_call)
|
||||
res_expl = self.explanation_param(self.display(res))
|
||||
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
|
||||
return res, outer_expl
|
||||
|
||||
def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]:
|
||||
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
|
||||
# A Starred node can appear in a function call.
|
||||
res, expl = self.visit(starred.value)
|
||||
new_starred = ast.Starred(res, starred.ctx)
|
||||
return new_starred, "*" + expl
|
||||
|
||||
def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
|
||||
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
|
||||
if not isinstance(attr.ctx, ast.Load):
|
||||
return self.generic_visit(attr)
|
||||
value, value_expl = self.visit(attr.value)
|
||||
res = self.assign(
|
||||
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
|
||||
)
|
||||
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
|
||||
res_expl = self.explanation_param(self.display(res))
|
||||
pat = "%s\n{%s = %s.%s\n}"
|
||||
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
|
||||
return res, expl
|
||||
|
||||
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
|
||||
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||
self.push_format_context()
|
||||
# We first check if we have overwritten a variable in the previous assert
|
||||
match comp.left:
|
||||
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
|
||||
self.scope, {}
|
||||
):
|
||||
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
|
||||
case ast.NamedExpr(target=ast.Name(id=target_id)):
|
||||
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
|
||||
if isinstance(
|
||||
comp.left, ast.Name
|
||||
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
|
||||
comp.left = self.variables_overwrite[self.scope][
|
||||
comp.left.id
|
||||
] # type:ignore[assignment]
|
||||
if isinstance(comp.left, namedExpr):
|
||||
self.variables_overwrite[self.scope][
|
||||
comp.left.target.id
|
||||
] = comp.left # type:ignore[assignment]
|
||||
left_res, left_expl = self.visit(comp.left)
|
||||
if isinstance(comp.left, ast.Compare | ast.BoolOp):
|
||||
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
||||
left_expl = f"({left_expl})"
|
||||
res_variables = [self.variable() for i in range(len(comp.ops))]
|
||||
load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
|
||||
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
|
||||
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
|
||||
it = zip(range(len(comp.ops)), comp.ops, comp.comparators, strict=True)
|
||||
expls: list[ast.expr] = []
|
||||
syms: list[ast.expr] = []
|
||||
it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
|
||||
expls = []
|
||||
syms = []
|
||||
results = [left_res]
|
||||
for i, op, next_operand in it:
|
||||
match (next_operand, left_res):
|
||||
case (
|
||||
ast.NamedExpr(target=ast.Name(id=target_id)),
|
||||
ast.Name(id=name_id),
|
||||
) if target_id == name_id:
|
||||
next_operand.target.id = self.variable()
|
||||
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
|
||||
|
||||
if (
|
||||
isinstance(next_operand, namedExpr)
|
||||
and isinstance(left_res, ast.Name)
|
||||
and next_operand.target.id == left_res.id
|
||||
):
|
||||
next_operand.target.id = self.variable()
|
||||
self.variables_overwrite[self.scope][
|
||||
left_res.id
|
||||
] = next_operand # type:ignore[assignment]
|
||||
next_res, next_expl = self.visit(next_operand)
|
||||
if isinstance(next_operand, ast.Compare | ast.BoolOp):
|
||||
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
||||
next_expl = f"({next_expl})"
|
||||
results.append(next_res)
|
||||
sym = BINOP_MAP[op.__class__]
|
||||
syms.append(ast.Constant(sym))
|
||||
syms.append(astStr(sym))
|
||||
expl = f"{left_expl} {sym} {next_expl}"
|
||||
expls.append(ast.Constant(expl))
|
||||
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
|
||||
expls.append(astStr(expl))
|
||||
res_expr = ast.Compare(left_res, [op], [next_res])
|
||||
self.statements.append(ast.Assign([store_names[i]], res_expr))
|
||||
left_res, left_expl = next_res, next_expl
|
||||
# Use pytest.assertion.util._reprcompare if that's available.
|
||||
@@ -1179,10 +1191,7 @@ def try_makedirs(cache_dir: Path) -> bool:
|
||||
return False
|
||||
except OSError as e:
|
||||
# as of now, EROFS doesn't have an equivalent OSError-subclass
|
||||
#
|
||||
# squashfuse_ll returns ENOSYS "OSError: [Errno 38] Function not
|
||||
# implemented" for a read-only error
|
||||
if e.errno in {errno.EROFS, errno.ENOSYS}:
|
||||
if e.errno == errno.EROFS:
|
||||
return False
|
||||
raise
|
||||
return True
|
||||
@@ -1190,7 +1199,7 @@ def try_makedirs(cache_dir: Path) -> bool:
|
||||
|
||||
def get_cache_dir(file_path: Path) -> Path:
|
||||
"""Return the cache directory to write .pyc files for the given .py file path."""
|
||||
if sys.pycache_prefix:
|
||||
if sys.version_info >= (3, 8) and sys.pycache_prefix:
|
||||
# given:
|
||||
# prefix = '/tmp/pycs'
|
||||
# path = '/home/user/proj/test_app.py'
|
||||
|
||||
Reference in New Issue
Block a user