updates
This commit is contained in:
@@ -1,12 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""Support for presenting detailed information in failing assertions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Protocol
|
||||
from typing import Generator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from _pytest.assertion import rewrite
|
||||
@@ -18,7 +15,6 @@ from _pytest.config import hookimpl
|
||||
from _pytest.config.argparsing import Parser
|
||||
from _pytest.nodes import Item
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.main import Session
|
||||
|
||||
@@ -47,26 +43,6 @@ def pytest_addoption(parser: Parser) -> None:
|
||||
"Make sure to delete any previously generated pyc cache files.",
|
||||
)
|
||||
|
||||
parser.addini(
|
||||
"truncation_limit_lines",
|
||||
default=None,
|
||||
help="Set threshold of LINES after which truncation will take effect",
|
||||
)
|
||||
parser.addini(
|
||||
"truncation_limit_chars",
|
||||
default=None,
|
||||
help=("Set threshold of CHARS after which truncation will take effect"),
|
||||
)
|
||||
|
||||
Config._add_verbosity_ini(
|
||||
parser,
|
||||
Config.VERBOSITY_ASSERTIONS,
|
||||
help=(
|
||||
"Specify a verbosity level for assertions, overriding the main level. "
|
||||
"Higher levels will provide more detailed explanation when an assertion fails."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def register_assert_rewrite(*names: str) -> None:
|
||||
"""Register one or more module names to be rewritten on import.
|
||||
@@ -83,18 +59,15 @@ def register_assert_rewrite(*names: str) -> None:
|
||||
if not isinstance(name, str):
|
||||
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
|
||||
raise TypeError(msg.format(repr(names)))
|
||||
rewrite_hook: RewriteHook
|
||||
for hook in sys.meta_path:
|
||||
if isinstance(hook, rewrite.AssertionRewritingHook):
|
||||
rewrite_hook = hook
|
||||
importhook = hook
|
||||
break
|
||||
else:
|
||||
rewrite_hook = DummyRewriteHook()
|
||||
rewrite_hook.mark_rewrite(*names)
|
||||
|
||||
|
||||
class RewriteHook(Protocol):
|
||||
def mark_rewrite(self, *names: str) -> None: ...
|
||||
# TODO(typing): Add a protocol for mark_rewrite() and use it
|
||||
# for importhook and for PytestPluginManager.rewrite_hook.
|
||||
importhook = DummyRewriteHook() # type: ignore
|
||||
importhook.mark_rewrite(*names)
|
||||
|
||||
|
||||
class DummyRewriteHook:
|
||||
@@ -110,7 +83,7 @@ class AssertionState:
|
||||
def __init__(self, config: Config, mode) -> None:
|
||||
self.mode = mode
|
||||
self.trace = config.trace.root.get("assertion")
|
||||
self.hook: rewrite.AssertionRewritingHook | None = None
|
||||
self.hook: Optional[rewrite.AssertionRewritingHook] = None
|
||||
|
||||
|
||||
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
|
||||
@@ -129,7 +102,7 @@ def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
|
||||
return hook
|
||||
|
||||
|
||||
def pytest_collection(session: Session) -> None:
|
||||
def pytest_collection(session: "Session") -> None:
|
||||
# This hook is only called when test modules are collected
|
||||
# so for example not in the managing process of pytest-xdist
|
||||
# (which does not collect test modules).
|
||||
@@ -139,17 +112,18 @@ def pytest_collection(session: Session) -> None:
|
||||
assertstate.hook.set_session(session)
|
||||
|
||||
|
||||
@hookimpl(wrapper=True, tryfirst=True)
|
||||
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
|
||||
@hookimpl(tryfirst=True, hookwrapper=True)
|
||||
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
|
||||
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
|
||||
|
||||
The rewrite module will use util._reprcompare if it exists to use custom
|
||||
reporting via the pytest_assertrepr_compare hook. This sets up this custom
|
||||
comparison for the test.
|
||||
"""
|
||||
|
||||
ihook = item.ihook
|
||||
|
||||
def callbinrepr(op, left: object, right: object) -> str | None:
|
||||
def callbinrepr(op, left: object, right: object) -> Optional[str]:
|
||||
"""Call the pytest_assertrepr_compare hook and prepare the result.
|
||||
|
||||
This uses the first result from the hook and then ensures the
|
||||
@@ -188,14 +162,13 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
|
||||
|
||||
util._assertion_pass = call_assertion_pass_hook
|
||||
|
||||
try:
|
||||
return (yield)
|
||||
finally:
|
||||
util._reprcompare, util._assertion_pass = saved_assert_hooks
|
||||
util._config = None
|
||||
yield
|
||||
|
||||
util._reprcompare, util._assertion_pass = saved_assert_hooks
|
||||
util._config = None
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: Session) -> None:
|
||||
def pytest_sessionfinish(session: "Session") -> None:
|
||||
assertstate = session.config.stash.get(assertstate_key, None)
|
||||
if assertstate:
|
||||
if assertstate.hook is not None:
|
||||
@@ -204,5 +177,5 @@ def pytest_sessionfinish(session: Session) -> None:
|
||||
|
||||
def pytest_assertrepr_compare(
|
||||
config: Config, op: str, left: Any, right: Any
|
||||
) -> list[str] | None:
|
||||
) -> Optional[List[str]]:
|
||||
return util.assertrepr_compare(config=config, op=op, left=left, right=right)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,13 +1,5 @@
|
||||
"""Rewrite assertion AST to produce nice error messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
import errno
|
||||
import functools
|
||||
import importlib.abc
|
||||
@@ -17,46 +9,53 @@ import io
|
||||
import itertools
|
||||
import marshal
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pathlib import PurePath
|
||||
import struct
|
||||
import sys
|
||||
import tokenize
|
||||
import types
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from pathlib import PurePath
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import IO
|
||||
from typing import Iterable
|
||||
from typing import Iterator
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Set
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from importlib.resources.abc import TraversableResources
|
||||
else:
|
||||
from importlib.abc import TraversableResources
|
||||
if sys.version_info < (3, 11):
|
||||
from importlib.readers import FileReader
|
||||
else:
|
||||
from importlib.resources.readers import FileReader
|
||||
|
||||
from typing import Union
|
||||
|
||||
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
|
||||
from _pytest._io.saferepr import saferepr
|
||||
from _pytest._io.saferepr import saferepr_unlimited
|
||||
from _pytest._version import version
|
||||
from _pytest.assertion import util
|
||||
from _pytest.assertion.util import ( # noqa: F401
|
||||
format_explanation as _format_explanation,
|
||||
)
|
||||
from _pytest.config import Config
|
||||
from _pytest.fixtures import FixtureFunctionDefinition
|
||||
from _pytest.main import Session
|
||||
from _pytest.pathlib import absolutepath
|
||||
from _pytest.pathlib import fnmatch_ex
|
||||
from _pytest.stash import StashKey
|
||||
|
||||
|
||||
# fmt: off
|
||||
from _pytest.assertion.util import format_explanation as _format_explanation # noqa:F401, isort:skip
|
||||
# fmt:on
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.assertion import AssertionState
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
namedExpr = ast.NamedExpr
|
||||
astNameConstant = ast.Constant
|
||||
astStr = ast.Constant
|
||||
astNum = ast.Constant
|
||||
else:
|
||||
namedExpr = ast.Expr
|
||||
astNameConstant = ast.NameConstant
|
||||
astStr = ast.Str
|
||||
astNum = ast.Num
|
||||
|
||||
|
||||
class Sentinel:
|
||||
pass
|
||||
@@ -66,7 +65,7 @@ assertstate_key = StashKey["AssertionState"]()
|
||||
|
||||
# pytest caches rewritten pycs in pycache dirs
|
||||
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
|
||||
PYC_EXT = ".py" + ((__debug__ and "c") or "o")
|
||||
PYC_EXT = ".py" + (__debug__ and "c" or "o")
|
||||
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
|
||||
|
||||
# Special marker that denotes we have just left a scope definition
|
||||
@@ -82,17 +81,17 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
self.fnpats = config.getini("python_files")
|
||||
except ValueError:
|
||||
self.fnpats = ["test_*.py", "*_test.py"]
|
||||
self.session: Session | None = None
|
||||
self._rewritten_names: dict[str, Path] = {}
|
||||
self._must_rewrite: set[str] = set()
|
||||
self.session: Optional[Session] = None
|
||||
self._rewritten_names: Dict[str, Path] = {}
|
||||
self._must_rewrite: Set[str] = set()
|
||||
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
|
||||
# which might result in infinite recursion (#3506)
|
||||
self._writing_pyc = False
|
||||
self._basenames_to_check_rewrite = {"conftest"}
|
||||
self._marked_for_rewrite_cache: dict[str, bool] = {}
|
||||
self._marked_for_rewrite_cache: Dict[str, bool] = {}
|
||||
self._session_paths_checked = False
|
||||
|
||||
def set_session(self, session: Session | None) -> None:
|
||||
def set_session(self, session: Optional[Session]) -> None:
|
||||
self.session = session
|
||||
self._session_paths_checked = False
|
||||
|
||||
@@ -102,28 +101,18 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
def find_spec(
|
||||
self,
|
||||
name: str,
|
||||
path: Sequence[str | bytes] | None = None,
|
||||
target: types.ModuleType | None = None,
|
||||
) -> importlib.machinery.ModuleSpec | None:
|
||||
path: Optional[Sequence[Union[str, bytes]]] = None,
|
||||
target: Optional[types.ModuleType] = None,
|
||||
) -> Optional[importlib.machinery.ModuleSpec]:
|
||||
if self._writing_pyc:
|
||||
return None
|
||||
state = self.config.stash[assertstate_key]
|
||||
if self._early_rewrite_bailout(name, state):
|
||||
return None
|
||||
state.trace(f"find_module called for: {name}")
|
||||
state.trace("find_module called for: %s" % name)
|
||||
|
||||
# Type ignored because mypy is confused about the `self` binding here.
|
||||
spec = self._find_spec(name, path) # type: ignore
|
||||
|
||||
if spec is None and path is not None:
|
||||
# With --import-mode=importlib, PathFinder cannot find spec without modifying `sys.path`,
|
||||
# causing inability to assert rewriting (#12659).
|
||||
# At this point, try using the file path to find the module spec.
|
||||
for _path_str in path:
|
||||
spec = importlib.util.spec_from_file_location(name, _path_str)
|
||||
if spec is not None:
|
||||
break
|
||||
|
||||
if (
|
||||
# the import machinery could not find a file to import
|
||||
spec is None
|
||||
@@ -151,7 +140,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
|
||||
def create_module(
|
||||
self, spec: importlib.machinery.ModuleSpec
|
||||
) -> types.ModuleType | None:
|
||||
) -> Optional[types.ModuleType]:
|
||||
return None # default behaviour is fine
|
||||
|
||||
def exec_module(self, module: types.ModuleType) -> None:
|
||||
@@ -196,7 +185,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
state.trace(f"found cached rewritten pyc for {fn}")
|
||||
exec(co, module.__dict__)
|
||||
|
||||
def _early_rewrite_bailout(self, name: str, state: AssertionState) -> bool:
|
||||
def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
|
||||
"""A fast way to get out of rewriting modules.
|
||||
|
||||
Profiling has shown that the call to PathFinder.find_spec (inside of
|
||||
@@ -235,7 +224,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
state.trace(f"early skip of rewriting module: {name}")
|
||||
return True
|
||||
|
||||
def _should_rewrite(self, name: str, fn: str, state: AssertionState) -> bool:
|
||||
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
|
||||
# always rewrite conftest files
|
||||
if os.path.basename(fn) == "conftest.py":
|
||||
state.trace(f"rewriting conftest file: {fn!r}")
|
||||
@@ -256,7 +245,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
|
||||
return self._is_marked_for_rewrite(name, state)
|
||||
|
||||
def _is_marked_for_rewrite(self, name: str, state: AssertionState) -> bool:
|
||||
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
|
||||
try:
|
||||
return self._marked_for_rewrite_cache[name]
|
||||
except KeyError:
|
||||
@@ -292,18 +281,31 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||
|
||||
self.config.issue_config_time_warning(
|
||||
PytestAssertRewriteWarning(
|
||||
f"Module already imported so cannot be rewritten; {name}"
|
||||
"Module already imported so cannot be rewritten: %s" % name
|
||||
),
|
||||
stacklevel=5,
|
||||
)
|
||||
|
||||
def get_data(self, pathname: str | bytes) -> bytes:
|
||||
def get_data(self, pathname: Union[str, bytes]) -> bytes:
|
||||
"""Optional PEP302 get_data API."""
|
||||
with open(pathname, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def get_resource_reader(self, name: str) -> TraversableResources:
|
||||
return FileReader(types.SimpleNamespace(path=self._rewritten_names[name])) # type: ignore[arg-type]
|
||||
if sys.version_info >= (3, 10):
|
||||
if sys.version_info >= (3, 12):
|
||||
from importlib.resources.abc import TraversableResources
|
||||
else:
|
||||
from importlib.abc import TraversableResources
|
||||
|
||||
def get_resource_reader(self, name: str) -> TraversableResources: # type: ignore
|
||||
if sys.version_info < (3, 11):
|
||||
from importlib.readers import FileReader
|
||||
else:
|
||||
from importlib.resources.readers import FileReader
|
||||
|
||||
return FileReader( # type:ignore[no-any-return]
|
||||
types.SimpleNamespace(path=self._rewritten_names[name])
|
||||
)
|
||||
|
||||
|
||||
def _write_pyc_fp(
|
||||
@@ -325,7 +327,7 @@ def _write_pyc_fp(
|
||||
|
||||
|
||||
def _write_pyc(
|
||||
state: AssertionState,
|
||||
state: "AssertionState",
|
||||
co: types.CodeType,
|
||||
source_stat: os.stat_result,
|
||||
pyc: Path,
|
||||
@@ -349,7 +351,7 @@ def _write_pyc(
|
||||
return True
|
||||
|
||||
|
||||
def _rewrite_test(fn: Path, config: Config) -> tuple[os.stat_result, types.CodeType]:
|
||||
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
|
||||
"""Read and rewrite *fn* and return the code object."""
|
||||
stat = os.stat(fn)
|
||||
source = fn.read_bytes()
|
||||
@@ -362,7 +364,7 @@ def _rewrite_test(fn: Path, config: Config) -> tuple[os.stat_result, types.CodeT
|
||||
|
||||
def _read_pyc(
|
||||
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
|
||||
) -> types.CodeType | None:
|
||||
) -> Optional[types.CodeType]:
|
||||
"""Possibly read a pytest pyc containing rewritten code.
|
||||
|
||||
Return rewritten code if successful or None if not.
|
||||
@@ -382,21 +384,21 @@ def _read_pyc(
|
||||
return None
|
||||
# Check for invalid or out of date pyc file.
|
||||
if len(data) != (16):
|
||||
trace(f"_read_pyc({source}): invalid pyc (too short)")
|
||||
trace("_read_pyc(%s): invalid pyc (too short)" % source)
|
||||
return None
|
||||
if data[:4] != importlib.util.MAGIC_NUMBER:
|
||||
trace(f"_read_pyc({source}): invalid pyc (bad magic number)")
|
||||
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
|
||||
return None
|
||||
if data[4:8] != b"\x00\x00\x00\x00":
|
||||
trace(f"_read_pyc({source}): invalid pyc (unsupported flags)")
|
||||
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
|
||||
return None
|
||||
mtime_data = data[8:12]
|
||||
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
|
||||
trace(f"_read_pyc({source}): out of date")
|
||||
trace("_read_pyc(%s): out of date" % source)
|
||||
return None
|
||||
size_data = data[12:16]
|
||||
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
|
||||
trace(f"_read_pyc({source}): invalid pyc (incorrect size)")
|
||||
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
|
||||
return None
|
||||
try:
|
||||
co = marshal.load(fp)
|
||||
@@ -404,7 +406,7 @@ def _read_pyc(
|
||||
trace(f"_read_pyc({source}): marshal.load error {e}")
|
||||
return None
|
||||
if not isinstance(co, types.CodeType):
|
||||
trace(f"_read_pyc({source}): not a code object")
|
||||
trace("_read_pyc(%s): not a code object" % source)
|
||||
return None
|
||||
return co
|
||||
|
||||
@@ -412,8 +414,8 @@ def _read_pyc(
|
||||
def rewrite_asserts(
|
||||
mod: ast.Module,
|
||||
source: bytes,
|
||||
module_path: str | None = None,
|
||||
config: Config | None = None,
|
||||
module_path: Optional[str] = None,
|
||||
config: Optional[Config] = None,
|
||||
) -> None:
|
||||
"""Rewrite the assert statements in mod."""
|
||||
AssertionRewriter(module_path, config, source).run(mod)
|
||||
@@ -429,22 +431,13 @@ def _saferepr(obj: object) -> str:
|
||||
sequences, especially '\n{' and '\n}' are likely to be present in
|
||||
JSON reprs.
|
||||
"""
|
||||
if isinstance(obj, types.MethodType):
|
||||
# for bound methods, skip redundant <bound method ...> information
|
||||
return obj.__name__
|
||||
|
||||
maxsize = _get_maxsize_for_saferepr(util._config)
|
||||
if not maxsize:
|
||||
return saferepr_unlimited(obj).replace("\n", "\\n")
|
||||
return saferepr(obj, maxsize=maxsize).replace("\n", "\\n")
|
||||
|
||||
|
||||
def _get_maxsize_for_saferepr(config: Config | None) -> int | None:
|
||||
def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
|
||||
"""Get `maxsize` configuration for saferepr based on the given config object."""
|
||||
if config is None:
|
||||
verbosity = 0
|
||||
else:
|
||||
verbosity = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
|
||||
verbosity = config.getoption("verbose") if config is not None else 0
|
||||
if verbosity >= 2:
|
||||
return None
|
||||
if verbosity >= 1:
|
||||
@@ -465,7 +458,7 @@ def _format_assertmsg(obj: object) -> str:
|
||||
# However in either case we want to preserve the newline.
|
||||
replaces = [("\n", "\n~"), ("%", "%%")]
|
||||
if not isinstance(obj, str):
|
||||
obj = saferepr(obj, _get_maxsize_for_saferepr(util._config))
|
||||
obj = saferepr(obj)
|
||||
replaces.append(("\\n", "\n~"))
|
||||
|
||||
for r1, r2 in replaces:
|
||||
@@ -476,8 +469,7 @@ def _format_assertmsg(obj: object) -> str:
|
||||
|
||||
def _should_repr_global_name(obj: object) -> bool:
|
||||
if callable(obj):
|
||||
# For pytest fixtures the __repr__ method provides more information than the function name.
|
||||
return isinstance(obj, FixtureFunctionDefinition)
|
||||
return False
|
||||
|
||||
try:
|
||||
return not hasattr(obj, "__name__")
|
||||
@@ -486,7 +478,7 @@ def _should_repr_global_name(obj: object) -> bool:
|
||||
|
||||
|
||||
def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
|
||||
explanation = "(" + ((is_or and " or ") or " and ").join(explanations) + ")"
|
||||
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
||||
return explanation.replace("%", "%%")
|
||||
|
||||
|
||||
@@ -496,7 +488,7 @@ def _call_reprcompare(
|
||||
expls: Sequence[str],
|
||||
each_obj: Sequence[object],
|
||||
) -> str:
|
||||
for i, res, expl in zip(range(len(ops)), results, expls, strict=True):
|
||||
for i, res, expl in zip(range(len(ops)), results, expls):
|
||||
try:
|
||||
done = not res
|
||||
except Exception:
|
||||
@@ -558,14 +550,14 @@ def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _get_assertion_exprs(src: bytes) -> dict[int, str]:
|
||||
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
|
||||
"""Return a mapping from {lineno: "assertion test expression"}."""
|
||||
ret: dict[int, str] = {}
|
||||
ret: Dict[int, str] = {}
|
||||
|
||||
depth = 0
|
||||
lines: list[str] = []
|
||||
assert_lineno: int | None = None
|
||||
seen_lines: set[int] = set()
|
||||
lines: List[str] = []
|
||||
assert_lineno: Optional[int] = None
|
||||
seen_lines: Set[int] = set()
|
||||
|
||||
def _write_and_reset() -> None:
|
||||
nonlocal depth, lines, assert_lineno, seen_lines
|
||||
@@ -599,7 +591,7 @@ def _get_assertion_exprs(src: bytes) -> dict[int, str]:
|
||||
# multi-line assert with message
|
||||
elif lineno in seen_lines:
|
||||
lines[-1] = lines[-1][:offset]
|
||||
# multi line assert with escaped newline before message
|
||||
# multi line assert with escapd newline before message
|
||||
else:
|
||||
lines.append(line[:offset])
|
||||
_write_and_reset()
|
||||
@@ -672,7 +664,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, module_path: str | None, config: Config | None, source: bytes
|
||||
self, module_path: Optional[str], config: Optional[Config], source: bytes
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.module_path = module_path
|
||||
@@ -685,9 +677,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
self.enable_assertion_pass_hook = False
|
||||
self.source = source
|
||||
self.scope: tuple[ast.AST, ...] = ()
|
||||
self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = (
|
||||
defaultdict(dict)
|
||||
)
|
||||
self.variables_overwrite: defaultdict[
|
||||
tuple[ast.AST, ...], Dict[str, str]
|
||||
] = defaultdict(dict)
|
||||
|
||||
def run(self, mod: ast.Module) -> None:
|
||||
"""Find all assert statements in *mod* and rewrite them."""
|
||||
@@ -702,18 +694,28 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
if doc is not None and self.is_rewrite_disabled(doc):
|
||||
return
|
||||
pos = 0
|
||||
item = None
|
||||
for item in mod.body:
|
||||
match item:
|
||||
case ast.Expr(value=ast.Constant(value=str() as doc)) if (
|
||||
expect_docstring
|
||||
):
|
||||
if self.is_rewrite_disabled(doc):
|
||||
return
|
||||
expect_docstring = False
|
||||
case ast.ImportFrom(level=0, module="__future__"):
|
||||
pass
|
||||
case _:
|
||||
break
|
||||
if (
|
||||
expect_docstring
|
||||
and isinstance(item, ast.Expr)
|
||||
and isinstance(item.value, astStr)
|
||||
):
|
||||
if sys.version_info >= (3, 8):
|
||||
doc = item.value.value
|
||||
else:
|
||||
doc = item.value.s
|
||||
if self.is_rewrite_disabled(doc):
|
||||
return
|
||||
expect_docstring = False
|
||||
elif (
|
||||
isinstance(item, ast.ImportFrom)
|
||||
and item.level == 0
|
||||
and item.module == "__future__"
|
||||
):
|
||||
pass
|
||||
else:
|
||||
break
|
||||
pos += 1
|
||||
# Special case: for a decorated function, set the lineno to that of the
|
||||
# first decorator, not the `def`. Issue #4984.
|
||||
@@ -722,15 +724,21 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
else:
|
||||
lineno = item.lineno
|
||||
# Now actually insert the special imports.
|
||||
aliases = [
|
||||
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
|
||||
ast.alias(
|
||||
"_pytest.assertion.rewrite",
|
||||
"@pytest_ar",
|
||||
lineno=lineno,
|
||||
col_offset=0,
|
||||
),
|
||||
]
|
||||
if sys.version_info >= (3, 10):
|
||||
aliases = [
|
||||
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
|
||||
ast.alias(
|
||||
"_pytest.assertion.rewrite",
|
||||
"@pytest_ar",
|
||||
lineno=lineno,
|
||||
col_offset=0,
|
||||
),
|
||||
]
|
||||
else:
|
||||
aliases = [
|
||||
ast.alias("builtins", "@py_builtins"),
|
||||
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
|
||||
]
|
||||
imports = [
|
||||
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
|
||||
]
|
||||
@@ -738,10 +746,10 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
|
||||
# Collect asserts.
|
||||
self.scope = (mod,)
|
||||
nodes: list[ast.AST | Sentinel] = [mod]
|
||||
nodes: List[Union[ast.AST, Sentinel]] = [mod]
|
||||
while nodes:
|
||||
node = nodes.pop()
|
||||
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
self.scope = tuple((*self.scope, node))
|
||||
nodes.append(_SCOPE_END_MARKER)
|
||||
if node == _SCOPE_END_MARKER:
|
||||
@@ -750,7 +758,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
assert isinstance(node, ast.AST)
|
||||
for name, field in ast.iter_fields(node):
|
||||
if isinstance(field, list):
|
||||
new: list[ast.AST] = []
|
||||
new: List[ast.AST] = []
|
||||
for i, child in enumerate(field):
|
||||
if isinstance(child, ast.Assert):
|
||||
# Transform assert.
|
||||
@@ -783,7 +791,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
"""Give *expr* a name."""
|
||||
name = self.variable()
|
||||
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
||||
return ast.copy_location(ast.Name(name, ast.Load()), expr)
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
def display(self, expr: ast.expr) -> ast.expr:
|
||||
"""Call saferepr on the expression."""
|
||||
@@ -822,7 +830,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
to format a string of %-formatted values as added by
|
||||
.explanation_param().
|
||||
"""
|
||||
self.explanation_specifiers: dict[str, ast.expr] = {}
|
||||
self.explanation_specifiers: Dict[str, ast.expr] = {}
|
||||
self.stack.append(self.explanation_specifiers)
|
||||
|
||||
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
|
||||
@@ -836,7 +844,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
current = self.stack.pop()
|
||||
if self.stack:
|
||||
self.explanation_specifiers = self.stack[-1]
|
||||
keys: list[ast.expr | None] = [ast.Constant(key) for key in current.keys()]
|
||||
keys = [astStr(key) for key in current.keys()]
|
||||
format_dict = ast.Dict(keys, list(current.values()))
|
||||
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
||||
name = "@py_format" + str(next(self.variable_counter))
|
||||
@@ -845,13 +853,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
def generic_visit(self, node: ast.AST) -> tuple[ast.Name, str]:
|
||||
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
|
||||
"""Handle expressions we don't have custom code for."""
|
||||
assert isinstance(node, ast.expr)
|
||||
res = self.assign(node)
|
||||
return res, self.explanation_param(self.display(res))
|
||||
|
||||
def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
|
||||
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
|
||||
"""Return the AST statements to replace the ast.Assert instance.
|
||||
|
||||
This rewrites the test of an assertion to provide
|
||||
@@ -860,9 +868,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
the expression is false.
|
||||
"""
|
||||
if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
|
||||
import warnings
|
||||
|
||||
from _pytest.warning_types import PytestAssertRewriteWarning
|
||||
import warnings
|
||||
|
||||
# TODO: This assert should not be needed.
|
||||
assert self.module_path is not None
|
||||
@@ -875,15 +882,15 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
lineno=assert_.lineno,
|
||||
)
|
||||
|
||||
self.statements: list[ast.stmt] = []
|
||||
self.variables: list[str] = []
|
||||
self.statements: List[ast.stmt] = []
|
||||
self.variables: List[str] = []
|
||||
self.variable_counter = itertools.count()
|
||||
|
||||
if self.enable_assertion_pass_hook:
|
||||
self.format_variables: list[str] = []
|
||||
self.format_variables: List[str] = []
|
||||
|
||||
self.stack: list[dict[str, ast.expr]] = []
|
||||
self.expl_stmts: list[ast.stmt] = []
|
||||
self.stack: List[Dict[str, ast.expr]] = []
|
||||
self.expl_stmts: List[ast.stmt] = []
|
||||
self.push_format_context()
|
||||
# Rewrite assert into a bunch of statements.
|
||||
top_condition, explanation = self.visit(assert_.test)
|
||||
@@ -891,16 +898,16 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
negation = ast.UnaryOp(ast.Not(), top_condition)
|
||||
|
||||
if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
|
||||
msg = self.pop_format_context(ast.Constant(explanation))
|
||||
msg = self.pop_format_context(astStr(explanation))
|
||||
|
||||
# Failed
|
||||
if assert_.msg:
|
||||
assertmsg = self.helper("_format_assertmsg", assert_.msg)
|
||||
gluestr = "\n>assert "
|
||||
else:
|
||||
assertmsg = ast.Constant("")
|
||||
assertmsg = astStr("")
|
||||
gluestr = "assert "
|
||||
err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
|
||||
err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
|
||||
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
|
||||
err_name = ast.Name("AssertionError", ast.Load())
|
||||
fmt = self.helper("_format_explanation", err_msg)
|
||||
@@ -916,27 +923,27 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
hook_call_pass = ast.Expr(
|
||||
self.helper(
|
||||
"_call_assertion_pass",
|
||||
ast.Constant(assert_.lineno),
|
||||
ast.Constant(orig),
|
||||
astNum(assert_.lineno),
|
||||
astStr(orig),
|
||||
fmt_pass,
|
||||
)
|
||||
)
|
||||
# If any hooks implement assert_pass hook
|
||||
hook_impl_test = ast.If(
|
||||
self.helper("_check_if_assertion_pass_impl"),
|
||||
[*self.expl_stmts, hook_call_pass],
|
||||
self.expl_stmts + [hook_call_pass],
|
||||
[],
|
||||
)
|
||||
statements_pass: list[ast.stmt] = [hook_impl_test]
|
||||
statements_pass = [hook_impl_test]
|
||||
|
||||
# Test for assertion condition
|
||||
main_test = ast.If(negation, statements_fail, statements_pass)
|
||||
self.statements.append(main_test)
|
||||
if self.format_variables:
|
||||
variables: list[ast.expr] = [
|
||||
variables = [
|
||||
ast.Name(name, ast.Store()) for name in self.format_variables
|
||||
]
|
||||
clear_format = ast.Assign(variables, ast.Constant(None))
|
||||
clear_format = ast.Assign(variables, astNameConstant(None))
|
||||
self.statements.append(clear_format)
|
||||
|
||||
else: # Original assertion rewriting
|
||||
@@ -947,9 +954,9 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
assertmsg = self.helper("_format_assertmsg", assert_.msg)
|
||||
explanation = "\n>assert " + explanation
|
||||
else:
|
||||
assertmsg = ast.Constant("")
|
||||
assertmsg = astStr("")
|
||||
explanation = "assert " + explanation
|
||||
template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
|
||||
template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
|
||||
msg = self.pop_format_context(template)
|
||||
fmt = self.helper("_format_explanation", msg)
|
||||
err_name = ast.Name("AssertionError", ast.Load())
|
||||
@@ -961,40 +968,37 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
# Clear temporary variables by setting them to None.
|
||||
if self.variables:
|
||||
variables = [ast.Name(name, ast.Store()) for name in self.variables]
|
||||
clear = ast.Assign(variables, ast.Constant(None))
|
||||
clear = ast.Assign(variables, astNameConstant(None))
|
||||
self.statements.append(clear)
|
||||
# Fix locations (line numbers/column offsets).
|
||||
for stmt in self.statements:
|
||||
for node in traverse_node(stmt):
|
||||
if getattr(node, "lineno", None) is None:
|
||||
# apply the assertion location to all generated ast nodes without source location
|
||||
# and preserve the location of existing nodes or generated nodes with an correct location.
|
||||
ast.copy_location(node, assert_)
|
||||
ast.copy_location(node, assert_)
|
||||
return self.statements
|
||||
|
||||
def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
|
||||
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
|
||||
# This method handles the 'walrus operator' repr of the target
|
||||
# name if it's a local variable or _should_repr_global_name()
|
||||
# thinks it's acceptable.
|
||||
locs = ast.Call(self.builtin("locals"), [], [])
|
||||
target_id = name.target.id
|
||||
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
|
||||
target_id = name.target.id # type: ignore[attr-defined]
|
||||
inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs])
|
||||
dorepr = self.helper("_should_repr_global_name", name)
|
||||
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
|
||||
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
|
||||
expr = ast.IfExp(test, self.display(name), astStr(target_id))
|
||||
return name, self.explanation_param(expr)
|
||||
|
||||
def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
|
||||
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
|
||||
# Display the repr of the name if it's a local variable or
|
||||
# _should_repr_global_name() thinks it's acceptable.
|
||||
locs = ast.Call(self.builtin("locals"), [], [])
|
||||
inlocs = ast.Compare(ast.Constant(name.id), [ast.In()], [locs])
|
||||
inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs])
|
||||
dorepr = self.helper("_should_repr_global_name", name)
|
||||
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
|
||||
expr = ast.IfExp(test, self.display(name), ast.Constant(name.id))
|
||||
expr = ast.IfExp(test, self.display(name), astStr(name.id))
|
||||
return name, self.explanation_param(expr)
|
||||
|
||||
def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
|
||||
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
|
||||
res_var = self.variable()
|
||||
expl_list = self.assign(ast.List([], ast.Load()))
|
||||
app = ast.Attribute(expl_list, "append", ast.Load())
|
||||
@@ -1006,57 +1010,60 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
# Process each operand, short-circuiting if needed.
|
||||
for i, v in enumerate(boolop.values):
|
||||
if i:
|
||||
fail_inner: list[ast.stmt] = []
|
||||
fail_inner: List[ast.stmt] = []
|
||||
# cond is set in a prior loop iteration below
|
||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
|
||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
||||
self.expl_stmts = fail_inner
|
||||
match v:
|
||||
# Check if the left operand is an ast.NamedExpr and the value has already been visited
|
||||
case ast.Compare(
|
||||
left=ast.NamedExpr(target=ast.Name(id=target_id))
|
||||
) if target_id in [
|
||||
e.id for e in boolop.values[:i] if hasattr(e, "id")
|
||||
]:
|
||||
pytest_temp = self.variable()
|
||||
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
|
||||
# mypy's false positive, we're checking that the 'target' attribute exists.
|
||||
v.left.target.id = pytest_temp # type:ignore[attr-defined]
|
||||
# Check if the left operand is a namedExpr and the value has already been visited
|
||||
if (
|
||||
isinstance(v, ast.Compare)
|
||||
and isinstance(v.left, namedExpr)
|
||||
and v.left.target.id
|
||||
in [
|
||||
ast_expr.id
|
||||
for ast_expr in boolop.values[:i]
|
||||
if hasattr(ast_expr, "id")
|
||||
]
|
||||
):
|
||||
pytest_temp = self.variable()
|
||||
self.variables_overwrite[self.scope][
|
||||
v.left.target.id
|
||||
] = v.left # type:ignore[assignment]
|
||||
v.left.target.id = pytest_temp
|
||||
self.push_format_context()
|
||||
res, expl = self.visit(v)
|
||||
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
||||
expl_format = self.pop_format_context(ast.Constant(expl))
|
||||
expl_format = self.pop_format_context(astStr(expl))
|
||||
call = ast.Call(app, [expl_format], [])
|
||||
self.expl_stmts.append(ast.Expr(call))
|
||||
if i < levels:
|
||||
cond: ast.expr = res
|
||||
if is_or:
|
||||
cond = ast.UnaryOp(ast.Not(), cond)
|
||||
inner: list[ast.stmt] = []
|
||||
inner: List[ast.stmt] = []
|
||||
self.statements.append(ast.If(cond, inner, []))
|
||||
self.statements = body = inner
|
||||
self.statements = save
|
||||
self.expl_stmts = fail_save
|
||||
expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
|
||||
expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
|
||||
expl = self.pop_format_context(expl_template)
|
||||
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
||||
|
||||
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
|
||||
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
|
||||
pattern = UNARY_MAP[unary.op.__class__]
|
||||
operand_res, operand_expl = self.visit(unary.operand)
|
||||
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
|
||||
res = self.assign(ast.UnaryOp(unary.op, operand_res))
|
||||
return res, pattern % (operand_expl,)
|
||||
|
||||
def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
|
||||
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
|
||||
symbol = BINOP_MAP[binop.op.__class__]
|
||||
left_expr, left_expl = self.visit(binop.left)
|
||||
right_expr, right_expl = self.visit(binop.right)
|
||||
explanation = f"({left_expl} {symbol} {right_expl})"
|
||||
res = self.assign(
|
||||
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
|
||||
)
|
||||
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
|
||||
return res, explanation
|
||||
|
||||
def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
|
||||
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
||||
new_func, func_expl = self.visit(call.func)
|
||||
arg_expls = []
|
||||
new_args = []
|
||||
@@ -1065,16 +1072,19 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
|
||||
self.scope, {}
|
||||
):
|
||||
arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment]
|
||||
arg = self.variables_overwrite[self.scope][
|
||||
arg.id
|
||||
] # type:ignore[assignment]
|
||||
res, expl = self.visit(arg)
|
||||
arg_expls.append(expl)
|
||||
new_args.append(res)
|
||||
for keyword in call.keywords:
|
||||
match keyword.value:
|
||||
case ast.Name(id=id) if id in self.variables_overwrite.get(
|
||||
self.scope, {}
|
||||
):
|
||||
keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment]
|
||||
if isinstance(
|
||||
keyword.value, ast.Name
|
||||
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
|
||||
keyword.value = self.variables_overwrite[self.scope][
|
||||
keyword.value.id
|
||||
] # type:ignore[assignment]
|
||||
res, expl = self.visit(keyword.value)
|
||||
new_kwargs.append(ast.keyword(keyword.arg, res))
|
||||
if keyword.arg:
|
||||
@@ -1083,68 +1093,70 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
arg_expls.append("**" + expl)
|
||||
|
||||
expl = "{}({})".format(func_expl, ", ".join(arg_expls))
|
||||
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
|
||||
new_call = ast.Call(new_func, new_args, new_kwargs)
|
||||
res = self.assign(new_call)
|
||||
res_expl = self.explanation_param(self.display(res))
|
||||
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
|
||||
return res, outer_expl
|
||||
|
||||
def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]:
|
||||
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
|
||||
# A Starred node can appear in a function call.
|
||||
res, expl = self.visit(starred.value)
|
||||
new_starred = ast.Starred(res, starred.ctx)
|
||||
return new_starred, "*" + expl
|
||||
|
||||
def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
|
||||
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
|
||||
if not isinstance(attr.ctx, ast.Load):
|
||||
return self.generic_visit(attr)
|
||||
value, value_expl = self.visit(attr.value)
|
||||
res = self.assign(
|
||||
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
|
||||
)
|
||||
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
|
||||
res_expl = self.explanation_param(self.display(res))
|
||||
pat = "%s\n{%s = %s.%s\n}"
|
||||
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
|
||||
return res, expl
|
||||
|
||||
def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
|
||||
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||
self.push_format_context()
|
||||
# We first check if we have overwritten a variable in the previous assert
|
||||
match comp.left:
|
||||
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
|
||||
self.scope, {}
|
||||
):
|
||||
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
|
||||
case ast.NamedExpr(target=ast.Name(id=target_id)):
|
||||
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
|
||||
if isinstance(
|
||||
comp.left, ast.Name
|
||||
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
|
||||
comp.left = self.variables_overwrite[self.scope][
|
||||
comp.left.id
|
||||
] # type:ignore[assignment]
|
||||
if isinstance(comp.left, namedExpr):
|
||||
self.variables_overwrite[self.scope][
|
||||
comp.left.target.id
|
||||
] = comp.left # type:ignore[assignment]
|
||||
left_res, left_expl = self.visit(comp.left)
|
||||
if isinstance(comp.left, ast.Compare | ast.BoolOp):
|
||||
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
||||
left_expl = f"({left_expl})"
|
||||
res_variables = [self.variable() for i in range(len(comp.ops))]
|
||||
load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
|
||||
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
|
||||
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
|
||||
it = zip(range(len(comp.ops)), comp.ops, comp.comparators, strict=True)
|
||||
expls: list[ast.expr] = []
|
||||
syms: list[ast.expr] = []
|
||||
it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
|
||||
expls = []
|
||||
syms = []
|
||||
results = [left_res]
|
||||
for i, op, next_operand in it:
|
||||
match (next_operand, left_res):
|
||||
case (
|
||||
ast.NamedExpr(target=ast.Name(id=target_id)),
|
||||
ast.Name(id=name_id),
|
||||
) if target_id == name_id:
|
||||
next_operand.target.id = self.variable()
|
||||
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
|
||||
|
||||
if (
|
||||
isinstance(next_operand, namedExpr)
|
||||
and isinstance(left_res, ast.Name)
|
||||
and next_operand.target.id == left_res.id
|
||||
):
|
||||
next_operand.target.id = self.variable()
|
||||
self.variables_overwrite[self.scope][
|
||||
left_res.id
|
||||
] = next_operand # type:ignore[assignment]
|
||||
next_res, next_expl = self.visit(next_operand)
|
||||
if isinstance(next_operand, ast.Compare | ast.BoolOp):
|
||||
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
||||
next_expl = f"({next_expl})"
|
||||
results.append(next_res)
|
||||
sym = BINOP_MAP[op.__class__]
|
||||
syms.append(ast.Constant(sym))
|
||||
syms.append(astStr(sym))
|
||||
expl = f"{left_expl} {sym} {next_expl}"
|
||||
expls.append(ast.Constant(expl))
|
||||
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
|
||||
expls.append(astStr(expl))
|
||||
res_expr = ast.Compare(left_res, [op], [next_res])
|
||||
self.statements.append(ast.Assign([store_names[i]], res_expr))
|
||||
left_res, left_expl = next_res, next_expl
|
||||
# Use pytest.assertion.util._reprcompare if that's available.
|
||||
@@ -1179,10 +1191,7 @@ def try_makedirs(cache_dir: Path) -> bool:
|
||||
return False
|
||||
except OSError as e:
|
||||
# as of now, EROFS doesn't have an equivalent OSError-subclass
|
||||
#
|
||||
# squashfuse_ll returns ENOSYS "OSError: [Errno 38] Function not
|
||||
# implemented" for a read-only error
|
||||
if e.errno in {errno.EROFS, errno.ENOSYS}:
|
||||
if e.errno == errno.EROFS:
|
||||
return False
|
||||
raise
|
||||
return True
|
||||
@@ -1190,7 +1199,7 @@ def try_makedirs(cache_dir: Path) -> bool:
|
||||
|
||||
def get_cache_dir(file_path: Path) -> Path:
|
||||
"""Return the cache directory to write .pyc files for the given .py file path."""
|
||||
if sys.pycache_prefix:
|
||||
if sys.version_info >= (3, 8) and sys.pycache_prefix:
|
||||
# given:
|
||||
# prefix = '/tmp/pycs'
|
||||
# path = '/home/user/proj/test_app.py'
|
||||
|
||||
@@ -1,65 +1,51 @@
|
||||
"""Utilities for truncating assertion output.
|
||||
|
||||
Current default behaviour is to truncate assertion explanations at
|
||||
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
|
||||
~8 terminal lines, unless running in "-vv" mode or running on CI.
|
||||
"""
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from _pytest.compat import running_on_ci
|
||||
from _pytest.config import Config
|
||||
from _pytest.assertion import util
|
||||
from _pytest.nodes import Item
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 8
|
||||
DEFAULT_MAX_CHARS = DEFAULT_MAX_LINES * 80
|
||||
DEFAULT_MAX_CHARS = 8 * 80
|
||||
USAGE_MSG = "use '-vv' to show"
|
||||
|
||||
|
||||
def truncate_if_required(explanation: list[str], item: Item) -> list[str]:
|
||||
def truncate_if_required(
|
||||
explanation: List[str], item: Item, max_length: Optional[int] = None
|
||||
) -> List[str]:
|
||||
"""Truncate this assertion explanation if the given test item is eligible."""
|
||||
should_truncate, max_lines, max_chars = _get_truncation_parameters(item)
|
||||
if should_truncate:
|
||||
return _truncate_explanation(
|
||||
explanation,
|
||||
max_lines=max_lines,
|
||||
max_chars=max_chars,
|
||||
)
|
||||
if _should_truncate_item(item):
|
||||
return _truncate_explanation(explanation)
|
||||
return explanation
|
||||
|
||||
|
||||
def _get_truncation_parameters(item: Item) -> tuple[bool, int, int]:
|
||||
"""Return the truncation parameters related to the given item, as (should truncate, max lines, max chars)."""
|
||||
# We do not need to truncate if one of conditions is met:
|
||||
# 1. Verbosity level is 2 or more;
|
||||
# 2. Test is being run in CI environment;
|
||||
# 3. Both truncation_limit_lines and truncation_limit_chars
|
||||
# .ini parameters are set to 0 explicitly.
|
||||
max_lines = item.config.getini("truncation_limit_lines")
|
||||
max_lines = int(max_lines if max_lines is not None else DEFAULT_MAX_LINES)
|
||||
|
||||
max_chars = item.config.getini("truncation_limit_chars")
|
||||
max_chars = int(max_chars if max_chars is not None else DEFAULT_MAX_CHARS)
|
||||
|
||||
verbose = item.config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
|
||||
|
||||
should_truncate = verbose < 2 and not running_on_ci()
|
||||
should_truncate = should_truncate and (max_lines > 0 or max_chars > 0)
|
||||
|
||||
return should_truncate, max_lines, max_chars
|
||||
def _should_truncate_item(item: Item) -> bool:
|
||||
"""Whether or not this test item is eligible for truncation."""
|
||||
verbose = item.config.option.verbose
|
||||
return verbose < 2 and not util.running_on_ci()
|
||||
|
||||
|
||||
def _truncate_explanation(
|
||||
input_lines: list[str],
|
||||
max_lines: int,
|
||||
max_chars: int,
|
||||
) -> list[str]:
|
||||
input_lines: List[str],
|
||||
max_lines: Optional[int] = None,
|
||||
max_chars: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
"""Truncate given list of strings that makes up the assertion explanation.
|
||||
|
||||
Truncates to either max_lines, or max_chars - whichever the input reaches
|
||||
Truncates to either 8 lines, or 640 characters - whichever the input reaches
|
||||
first, taking the truncation explanation into account. The remaining lines
|
||||
will be replaced by a usage message.
|
||||
"""
|
||||
if max_lines is None:
|
||||
max_lines = DEFAULT_MAX_LINES
|
||||
if max_chars is None:
|
||||
max_chars = DEFAULT_MAX_CHARS
|
||||
|
||||
# Check if truncation required
|
||||
input_char_count = len("".join(input_lines))
|
||||
# The length of the truncation explanation depends on the number of lines
|
||||
@@ -84,23 +70,16 @@ def _truncate_explanation(
|
||||
):
|
||||
return input_lines
|
||||
# Truncate first to max_lines, and then truncate to max_chars if necessary
|
||||
if max_lines > 0:
|
||||
truncated_explanation = input_lines[:max_lines]
|
||||
else:
|
||||
truncated_explanation = input_lines
|
||||
truncated_explanation = input_lines[:max_lines]
|
||||
truncated_char = True
|
||||
# We reevaluate the need to truncate chars following removal of some lines
|
||||
if len("".join(truncated_explanation)) > tolerable_max_chars and max_chars > 0:
|
||||
if len("".join(truncated_explanation)) > tolerable_max_chars:
|
||||
truncated_explanation = _truncate_by_char_count(
|
||||
truncated_explanation, max_chars
|
||||
)
|
||||
else:
|
||||
truncated_char = False
|
||||
|
||||
if truncated_explanation == input_lines:
|
||||
# No truncation happened, so we do not need to add any explanations
|
||||
return truncated_explanation
|
||||
|
||||
truncated_line_count = len(input_lines) - len(truncated_explanation)
|
||||
if truncated_explanation[-1]:
|
||||
# Add ellipsis and take into account part-truncated final line
|
||||
@@ -111,15 +90,14 @@ def _truncate_explanation(
|
||||
else:
|
||||
# Add proper ellipsis when we were able to fit a full line exactly
|
||||
truncated_explanation[-1] = "..."
|
||||
return [
|
||||
*truncated_explanation,
|
||||
return truncated_explanation + [
|
||||
"",
|
||||
f"...Full output truncated ({truncated_line_count} line"
|
||||
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
|
||||
]
|
||||
|
||||
|
||||
def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
|
||||
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
|
||||
# Find point at which input length exceeds total allowed length
|
||||
iterated_char_count = 0
|
||||
for iterated_index, input_line in enumerate(input_lines):
|
||||
|
||||
@@ -1,54 +1,36 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""Utilities for assertion debugging."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Set as AbstractSet
|
||||
import os
|
||||
import pprint
|
||||
from typing import AbstractSet
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Protocol
|
||||
from typing import Callable
|
||||
from typing import Iterable
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from unicodedata import normalize
|
||||
|
||||
from _pytest import outcomes
|
||||
import _pytest._code
|
||||
from _pytest._io.pprint import PrettyPrinter
|
||||
from _pytest import outcomes
|
||||
from _pytest._io.saferepr import _pformat_dispatch
|
||||
from _pytest._io.saferepr import saferepr
|
||||
from _pytest._io.saferepr import saferepr_unlimited
|
||||
from _pytest.compat import running_on_ci
|
||||
from _pytest.config import Config
|
||||
|
||||
|
||||
# The _reprcompare attribute on the util module is used by the new assertion
|
||||
# interpretation code and assertion rewriter to detect this plugin was
|
||||
# loaded and in turn call the hooks defined here as part of the
|
||||
# DebugInterpreter.
|
||||
_reprcompare: Callable[[str, object, object], str | None] | None = None
|
||||
_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None
|
||||
|
||||
# Works similarly as _reprcompare attribute. Is populated with the hook call
|
||||
# when pytest_runtest_setup is called.
|
||||
_assertion_pass: Callable[[int, str, str], None] | None = None
|
||||
_assertion_pass: Optional[Callable[[int, str, str], None]] = None
|
||||
|
||||
# Config object which is assigned during pytest_runtest_protocol.
|
||||
_config: Config | None = None
|
||||
|
||||
|
||||
class _HighlightFunc(Protocol):
|
||||
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
|
||||
"""Apply highlighting to the given source."""
|
||||
|
||||
|
||||
def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
|
||||
"""Dummy highlighter that returns the text unprocessed.
|
||||
|
||||
Needed for _notin_text, as the diff gets post-processed to only show the "+" part.
|
||||
"""
|
||||
return source
|
||||
_config: Optional[Config] = None
|
||||
|
||||
|
||||
def format_explanation(explanation: str) -> str:
|
||||
@@ -66,7 +48,7 @@ def format_explanation(explanation: str) -> str:
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def _split_explanation(explanation: str) -> list[str]:
|
||||
def _split_explanation(explanation: str) -> List[str]:
|
||||
r"""Return a list of individual lines in the explanation.
|
||||
|
||||
This will return a list of lines split on '\n{', '\n}' and '\n~'.
|
||||
@@ -83,7 +65,7 @@ def _split_explanation(explanation: str) -> list[str]:
|
||||
return lines
|
||||
|
||||
|
||||
def _format_lines(lines: Sequence[str]) -> list[str]:
|
||||
def _format_lines(lines: Sequence[str]) -> List[str]:
|
||||
"""Format the individual lines.
|
||||
|
||||
This will replace the '{', '}' and '~' characters of our mini formatting
|
||||
@@ -131,7 +113,7 @@ def isdict(x: Any) -> bool:
|
||||
|
||||
|
||||
def isset(x: Any) -> bool:
|
||||
return isinstance(x, set | frozenset)
|
||||
return isinstance(x, (set, frozenset))
|
||||
|
||||
|
||||
def isnamedtuple(obj: Any) -> bool:
|
||||
@@ -150,7 +132,7 @@ def isiterable(obj: Any) -> bool:
|
||||
try:
|
||||
iter(obj)
|
||||
return not istext(obj)
|
||||
except Exception:
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
@@ -169,7 +151,7 @@ def has_default_eq(
|
||||
code_filename = obj.__eq__.__code__.co_filename
|
||||
|
||||
if isattrs(obj):
|
||||
return "attrs generated " in code_filename
|
||||
return "attrs generated eq" in code_filename
|
||||
|
||||
return code_filename == "<string>" # data class
|
||||
return True
|
||||
@@ -177,9 +159,9 @@ def has_default_eq(
|
||||
|
||||
def assertrepr_compare(
|
||||
config, op: str, left: Any, right: Any, use_ascii: bool = False
|
||||
) -> list[str] | None:
|
||||
) -> Optional[List[str]]:
|
||||
"""Return specialised explanations for some operators/operands."""
|
||||
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
|
||||
verbose = config.getoption("verbose")
|
||||
|
||||
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
|
||||
# See issue #3246.
|
||||
@@ -203,54 +185,34 @@ def assertrepr_compare(
|
||||
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
|
||||
|
||||
summary = f"{left_repr} {op} {right_repr}"
|
||||
highlighter = config.get_terminal_writer()._highlight
|
||||
|
||||
explanation = None
|
||||
try:
|
||||
if op == "==":
|
||||
explanation = _compare_eq_any(left, right, highlighter, verbose)
|
||||
explanation = _compare_eq_any(left, right, verbose)
|
||||
elif op == "not in":
|
||||
if istext(left) and istext(right):
|
||||
explanation = _notin_text(left, right, verbose)
|
||||
elif op == "!=":
|
||||
if isset(left) and isset(right):
|
||||
explanation = ["Both sets are equal"]
|
||||
elif op == ">=":
|
||||
if isset(left) and isset(right):
|
||||
explanation = _compare_gte_set(left, right, highlighter, verbose)
|
||||
elif op == "<=":
|
||||
if isset(left) and isset(right):
|
||||
explanation = _compare_lte_set(left, right, highlighter, verbose)
|
||||
elif op == ">":
|
||||
if isset(left) and isset(right):
|
||||
explanation = _compare_gt_set(left, right, highlighter, verbose)
|
||||
elif op == "<":
|
||||
if isset(left) and isset(right):
|
||||
explanation = _compare_lt_set(left, right, highlighter, verbose)
|
||||
|
||||
except outcomes.Exit:
|
||||
raise
|
||||
except Exception:
|
||||
repr_crash = _pytest._code.ExceptionInfo.from_current()._getreprcrash()
|
||||
explanation = [
|
||||
f"(pytest_assertion plugin: representation of details failed: {repr_crash}.",
|
||||
"(pytest_assertion plugin: representation of details failed: {}.".format(
|
||||
_pytest._code.ExceptionInfo.from_current()._getreprcrash()
|
||||
),
|
||||
" Probably an object has a faulty __repr__.)",
|
||||
]
|
||||
|
||||
if not explanation:
|
||||
return None
|
||||
|
||||
if explanation[0] != "":
|
||||
explanation = ["", *explanation]
|
||||
return [summary, *explanation]
|
||||
return [summary] + explanation
|
||||
|
||||
|
||||
def _compare_eq_any(
|
||||
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
|
||||
) -> list[str]:
|
||||
def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
|
||||
explanation = []
|
||||
if istext(left) and istext(right):
|
||||
explanation = _diff_text(left, right, highlighter, verbose)
|
||||
explanation = _diff_text(left, right, verbose)
|
||||
else:
|
||||
from _pytest.python_api import ApproxBase
|
||||
|
||||
@@ -260,31 +222,29 @@ def _compare_eq_any(
|
||||
other_side = right if isinstance(left, ApproxBase) else left
|
||||
|
||||
explanation = approx_side._repr_compare(other_side)
|
||||
elif type(left) is type(right) and (
|
||||
elif type(left) == type(right) and (
|
||||
isdatacls(left) or isattrs(left) or isnamedtuple(left)
|
||||
):
|
||||
# Note: unlike dataclasses/attrs, namedtuples compare only the
|
||||
# field values, not the type or field names. But this branch
|
||||
# intentionally only handles the same-type case, which was often
|
||||
# used in older code bases before dataclasses/attrs were available.
|
||||
explanation = _compare_eq_cls(left, right, highlighter, verbose)
|
||||
explanation = _compare_eq_cls(left, right, verbose)
|
||||
elif issequence(left) and issequence(right):
|
||||
explanation = _compare_eq_sequence(left, right, highlighter, verbose)
|
||||
explanation = _compare_eq_sequence(left, right, verbose)
|
||||
elif isset(left) and isset(right):
|
||||
explanation = _compare_eq_set(left, right, highlighter, verbose)
|
||||
explanation = _compare_eq_set(left, right, verbose)
|
||||
elif isdict(left) and isdict(right):
|
||||
explanation = _compare_eq_dict(left, right, highlighter, verbose)
|
||||
explanation = _compare_eq_dict(left, right, verbose)
|
||||
|
||||
if isiterable(left) and isiterable(right):
|
||||
expl = _compare_eq_iterable(left, right, highlighter, verbose)
|
||||
expl = _compare_eq_iterable(left, right, verbose)
|
||||
explanation.extend(expl)
|
||||
|
||||
return explanation
|
||||
|
||||
|
||||
def _diff_text(
|
||||
left: str, right: str, highlighter: _HighlightFunc, verbose: int = 0
|
||||
) -> list[str]:
|
||||
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
|
||||
"""Return the explanation for the diff between text.
|
||||
|
||||
Unless --verbose is used this will skip leading and trailing
|
||||
@@ -292,7 +252,7 @@ def _diff_text(
|
||||
"""
|
||||
from difflib import ndiff
|
||||
|
||||
explanation: list[str] = []
|
||||
explanation: List[str] = []
|
||||
|
||||
if verbose < 1:
|
||||
i = 0 # just in case left or right has zero length
|
||||
@@ -302,7 +262,7 @@ def _diff_text(
|
||||
if i > 42:
|
||||
i -= 10 # Provide some context
|
||||
explanation = [
|
||||
f"Skipping {i} identical leading characters in diff, use -v to show"
|
||||
"Skipping %s identical leading characters in diff, use -v to show" % i
|
||||
]
|
||||
left = left[i:]
|
||||
right = right[i:]
|
||||
@@ -313,8 +273,8 @@ def _diff_text(
|
||||
if i > 42:
|
||||
i -= 10 # Provide some context
|
||||
explanation += [
|
||||
f"Skipping {i} identical trailing "
|
||||
"characters in diff, use -v to show"
|
||||
"Skipping {} identical trailing "
|
||||
"characters in diff, use -v to show".format(i)
|
||||
]
|
||||
left = left[:-i]
|
||||
right = right[:-i]
|
||||
@@ -325,55 +285,61 @@ def _diff_text(
|
||||
explanation += ["Strings contain only whitespace, escaping them using repr()"]
|
||||
# "right" is the expected base against which we compare "left",
|
||||
# see https://github.com/pytest-dev/pytest/issues/3333
|
||||
explanation.extend(
|
||||
highlighter(
|
||||
"\n".join(
|
||||
line.strip("\n")
|
||||
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
|
||||
),
|
||||
lexer="diff",
|
||||
).splitlines()
|
||||
)
|
||||
explanation += [
|
||||
line.strip("\n")
|
||||
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
|
||||
]
|
||||
return explanation
|
||||
|
||||
|
||||
def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
|
||||
"""Move opening/closing parenthesis/bracket to own lines."""
|
||||
opening = lines[0][:1]
|
||||
if opening in ["(", "[", "{"]:
|
||||
lines[0] = " " + lines[0][1:]
|
||||
lines[:] = [opening] + lines
|
||||
closing = lines[-1][-1:]
|
||||
if closing in [")", "]", "}"]:
|
||||
lines[-1] = lines[-1][:-1] + ","
|
||||
lines[:] = lines + [closing]
|
||||
|
||||
|
||||
def _compare_eq_iterable(
|
||||
left: Iterable[Any],
|
||||
right: Iterable[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
left: Iterable[Any], right: Iterable[Any], verbose: int = 0
|
||||
) -> List[str]:
|
||||
if verbose <= 0 and not running_on_ci():
|
||||
return ["Use -v to get more diff"]
|
||||
# dynamic import to speedup pytest
|
||||
import difflib
|
||||
|
||||
left_formatting = PrettyPrinter().pformat(left).splitlines()
|
||||
right_formatting = PrettyPrinter().pformat(right).splitlines()
|
||||
left_formatting = pprint.pformat(left).splitlines()
|
||||
right_formatting = pprint.pformat(right).splitlines()
|
||||
|
||||
explanation = ["", "Full diff:"]
|
||||
# Re-format for different output lengths.
|
||||
lines_left = len(left_formatting)
|
||||
lines_right = len(right_formatting)
|
||||
if lines_left != lines_right:
|
||||
left_formatting = _pformat_dispatch(left).splitlines()
|
||||
right_formatting = _pformat_dispatch(right).splitlines()
|
||||
|
||||
if lines_left > 1 or lines_right > 1:
|
||||
_surrounding_parens_on_own_lines(left_formatting)
|
||||
_surrounding_parens_on_own_lines(right_formatting)
|
||||
|
||||
explanation = ["Full diff:"]
|
||||
# "right" is the expected base against which we compare "left",
|
||||
# see https://github.com/pytest-dev/pytest/issues/3333
|
||||
explanation.extend(
|
||||
highlighter(
|
||||
"\n".join(
|
||||
line.rstrip()
|
||||
for line in difflib.ndiff(right_formatting, left_formatting)
|
||||
),
|
||||
lexer="diff",
|
||||
).splitlines()
|
||||
line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting)
|
||||
)
|
||||
return explanation
|
||||
|
||||
|
||||
def _compare_eq_sequence(
|
||||
left: Sequence[Any],
|
||||
right: Sequence[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
left: Sequence[Any], right: Sequence[Any], verbose: int = 0
|
||||
) -> List[str]:
|
||||
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
|
||||
explanation: list[str] = []
|
||||
explanation: List[str] = []
|
||||
len_left = len(left)
|
||||
len_right = len(right)
|
||||
for i in range(min(len_left, len_right)):
|
||||
@@ -393,10 +359,7 @@ def _compare_eq_sequence(
|
||||
left_value = left[i]
|
||||
right_value = right[i]
|
||||
|
||||
explanation.append(
|
||||
f"At index {i} diff:"
|
||||
f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}"
|
||||
)
|
||||
explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"]
|
||||
break
|
||||
|
||||
if comparing_bytes:
|
||||
@@ -416,134 +379,74 @@ def _compare_eq_sequence(
|
||||
extra = saferepr(right[len_left])
|
||||
|
||||
if len_diff == 1:
|
||||
explanation += [
|
||||
f"{dir_with_more} contains one more item: {highlighter(extra)}"
|
||||
]
|
||||
explanation += [f"{dir_with_more} contains one more item: {extra}"]
|
||||
else:
|
||||
explanation += [
|
||||
f"{dir_with_more} contains {len_diff} more items, first extra item: {highlighter(extra)}"
|
||||
"%s contains %d more items, first extra item: %s"
|
||||
% (dir_with_more, len_diff, extra)
|
||||
]
|
||||
return explanation
|
||||
|
||||
|
||||
def _compare_eq_set(
|
||||
left: AbstractSet[Any],
|
||||
right: AbstractSet[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
|
||||
) -> List[str]:
|
||||
explanation = []
|
||||
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
|
||||
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
|
||||
return explanation
|
||||
|
||||
|
||||
def _compare_gt_set(
|
||||
left: AbstractSet[Any],
|
||||
right: AbstractSet[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
explanation = _compare_gte_set(left, right, highlighter)
|
||||
if not explanation:
|
||||
return ["Both sets are equal"]
|
||||
return explanation
|
||||
|
||||
|
||||
def _compare_lt_set(
|
||||
left: AbstractSet[Any],
|
||||
right: AbstractSet[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
explanation = _compare_lte_set(left, right, highlighter)
|
||||
if not explanation:
|
||||
return ["Both sets are equal"]
|
||||
return explanation
|
||||
|
||||
|
||||
def _compare_gte_set(
|
||||
left: AbstractSet[Any],
|
||||
right: AbstractSet[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
return _set_one_sided_diff("right", right, left, highlighter)
|
||||
|
||||
|
||||
def _compare_lte_set(
|
||||
left: AbstractSet[Any],
|
||||
right: AbstractSet[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
return _set_one_sided_diff("left", left, right, highlighter)
|
||||
|
||||
|
||||
def _set_one_sided_diff(
|
||||
posn: str,
|
||||
set1: AbstractSet[Any],
|
||||
set2: AbstractSet[Any],
|
||||
highlighter: _HighlightFunc,
|
||||
) -> list[str]:
|
||||
explanation = []
|
||||
diff = set1 - set2
|
||||
if diff:
|
||||
explanation.append(f"Extra items in the {posn} set:")
|
||||
for item in diff:
|
||||
explanation.append(highlighter(saferepr(item)))
|
||||
diff_left = left - right
|
||||
diff_right = right - left
|
||||
if diff_left:
|
||||
explanation.append("Extra items in the left set:")
|
||||
for item in diff_left:
|
||||
explanation.append(saferepr(item))
|
||||
if diff_right:
|
||||
explanation.append("Extra items in the right set:")
|
||||
for item in diff_right:
|
||||
explanation.append(saferepr(item))
|
||||
return explanation
|
||||
|
||||
|
||||
def _compare_eq_dict(
|
||||
left: Mapping[Any, Any],
|
||||
right: Mapping[Any, Any],
|
||||
highlighter: _HighlightFunc,
|
||||
verbose: int = 0,
|
||||
) -> list[str]:
|
||||
explanation: list[str] = []
|
||||
left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
|
||||
) -> List[str]:
|
||||
explanation: List[str] = []
|
||||
set_left = set(left)
|
||||
set_right = set(right)
|
||||
common = set_left.intersection(set_right)
|
||||
same = {k: left[k] for k in common if left[k] == right[k]}
|
||||
if same and verbose < 2:
|
||||
explanation += [f"Omitting {len(same)} identical items, use -vv to show"]
|
||||
explanation += ["Omitting %s identical items, use -vv to show" % len(same)]
|
||||
elif same:
|
||||
explanation += ["Common items:"]
|
||||
explanation += highlighter(pprint.pformat(same)).splitlines()
|
||||
explanation += pprint.pformat(same).splitlines()
|
||||
diff = {k for k in common if left[k] != right[k]}
|
||||
if diff:
|
||||
explanation += ["Differing items:"]
|
||||
for k in diff:
|
||||
explanation += [
|
||||
highlighter(saferepr({k: left[k]}))
|
||||
+ " != "
|
||||
+ highlighter(saferepr({k: right[k]}))
|
||||
]
|
||||
explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})]
|
||||
extra_left = set_left - set_right
|
||||
len_extra_left = len(extra_left)
|
||||
if len_extra_left:
|
||||
explanation.append(
|
||||
f"Left contains {len_extra_left} more item{'' if len_extra_left == 1 else 's'}:"
|
||||
"Left contains %d more item%s:"
|
||||
% (len_extra_left, "" if len_extra_left == 1 else "s")
|
||||
)
|
||||
explanation.extend(
|
||||
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines()
|
||||
pprint.pformat({k: left[k] for k in extra_left}).splitlines()
|
||||
)
|
||||
extra_right = set_right - set_left
|
||||
len_extra_right = len(extra_right)
|
||||
if len_extra_right:
|
||||
explanation.append(
|
||||
f"Right contains {len_extra_right} more item{'' if len_extra_right == 1 else 's'}:"
|
||||
"Right contains %d more item%s:"
|
||||
% (len_extra_right, "" if len_extra_right == 1 else "s")
|
||||
)
|
||||
explanation.extend(
|
||||
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines()
|
||||
pprint.pformat({k: right[k] for k in extra_right}).splitlines()
|
||||
)
|
||||
return explanation
|
||||
|
||||
|
||||
def _compare_eq_cls(
|
||||
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
|
||||
) -> list[str]:
|
||||
def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
|
||||
if not has_default_eq(left):
|
||||
return []
|
||||
if isdatacls(left):
|
||||
@@ -572,37 +475,35 @@ def _compare_eq_cls(
|
||||
if same or diff:
|
||||
explanation += [""]
|
||||
if same and verbose < 2:
|
||||
explanation.append(f"Omitting {len(same)} identical items, use -vv to show")
|
||||
explanation.append("Omitting %s identical items, use -vv to show" % len(same))
|
||||
elif same:
|
||||
explanation += ["Matching attributes:"]
|
||||
explanation += highlighter(pprint.pformat(same)).splitlines()
|
||||
explanation += pprint.pformat(same).splitlines()
|
||||
if diff:
|
||||
explanation += ["Differing attributes:"]
|
||||
explanation += highlighter(pprint.pformat(diff)).splitlines()
|
||||
explanation += pprint.pformat(diff).splitlines()
|
||||
for field in diff:
|
||||
field_left = getattr(left, field)
|
||||
field_right = getattr(right, field)
|
||||
explanation += [
|
||||
"",
|
||||
f"Drill down into differing attribute {field}:",
|
||||
f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}",
|
||||
"Drill down into differing attribute %s:" % field,
|
||||
("%s%s: %r != %r") % (indent, field, field_left, field_right),
|
||||
]
|
||||
explanation += [
|
||||
indent + line
|
||||
for line in _compare_eq_any(
|
||||
field_left, field_right, highlighter, verbose
|
||||
)
|
||||
for line in _compare_eq_any(field_left, field_right, verbose)
|
||||
]
|
||||
return explanation
|
||||
|
||||
|
||||
def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
|
||||
def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
|
||||
index = text.find(term)
|
||||
head = text[:index]
|
||||
tail = text[index + len(term) :]
|
||||
correct_text = head + tail
|
||||
diff = _diff_text(text, correct_text, dummy_highlighter, verbose)
|
||||
newdiff = [f"{saferepr(term, maxsize=42)} is contained here:"]
|
||||
diff = _diff_text(text, correct_text, verbose)
|
||||
newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)]
|
||||
for line in diff:
|
||||
if line.startswith("Skipping"):
|
||||
continue
|
||||
@@ -613,3 +514,9 @@ def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
|
||||
else:
|
||||
newdiff.append(line)
|
||||
return newdiff
|
||||
|
||||
|
||||
def running_on_ci() -> bool:
|
||||
"""Check if we're currently running on a CI system."""
|
||||
env_vars = ["CI", "BUILD_NUMBER"]
|
||||
return any(var in os.environ for var in env_vars)
|
||||
|
||||
Reference in New Issue
Block a user