This commit is contained in:
Iliyan Angelov
2025-11-23 18:59:18 +02:00
parent be07802066
commit 627959f52b
1840 changed files with 236564 additions and 3475 deletions

7
Backend/venv/bin/coverage Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from coverage.cmdline import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/coverage-3.12 Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from coverage.cmdline import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/coverage3 Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from coverage.cmdline import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/httpx Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from httpx import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -1,8 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python3
# -*- coding: utf-8 -*-
import re
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -1,8 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python3
# -*- coding: utf-8 -*-
import re
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

View File

@@ -1,8 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python3
# -*- coding: utf-8 -*-
import re
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/py.test Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pytest import console_main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(console_main())

7
Backend/venv/bin/pygmentize Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pygments.cmdline import main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(main())

7
Backend/venv/bin/pytest Executable file
View File

@@ -0,0 +1,7 @@
#!/home/gnx/Desktop/Hotel-Booking/Backend/venv/bin/python
import sys
from pytest import console_main
if __name__ == '__main__':
if sys.argv[0].endswith('.exe'):
sys.argv[0] = sys.argv[0][:-4]
sys.exit(console_main())

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
__all__ = ["__version__", "version_tuple"]
try:
from ._version import version as __version__
from ._version import version_tuple
except ImportError: # pragma: no cover
# broken installation, we don't even try
# unknown only works because we do poor mans version compare
__version__ = "unknown"
version_tuple = (0, 0, "unknown")

View File

@@ -0,0 +1,117 @@
"""Allow bash-completion for argparse with argcomplete if installed.
Needs argcomplete>=0.5.6 for python 3.2/3.3 (older versions fail
to find the magic string, so _ARGCOMPLETE env. var is never set, and
this does not need special code).
Function try_argcomplete(parser) should be called directly before
the call to ArgumentParser.parse_args().
The filescompleter is what you normally would use on the positional
arguments specification, in order to get "dirname/" after "dirn<TAB>"
instead of the default "dirname ":
optparser.add_argument(Config._file_or_dir, nargs='*').completer=filescompleter
Other, application specific, completers should go in the file
doing the add_argument calls as they need to be specified as .completer
attributes as well. (If argcomplete is not installed, the function the
attribute points to will not be used).
SPEEDUP
=======
The generic argcomplete script for bash-completion
(/etc/bash_completion.d/python-argcomplete.sh)
uses a python program to determine startup script generated by pip.
You can speed up completion somewhat by changing this script to include
# PYTHON_ARGCOMPLETE_OK
so the python-argcomplete-check-easy-install-script does not
need to be called to find the entry point of the code and see if that is
marked with PYTHON_ARGCOMPLETE_OK.
INSTALL/DEBUGGING
=================
To include this support in another application that has setup.py generated
scripts:
- Add the line:
# PYTHON_ARGCOMPLETE_OK
near the top of the main python entry point.
- Include in the file calling parse_args():
from _argcomplete import try_argcomplete, filescompleter
Call try_argcomplete just before parse_args(), and optionally add
filescompleter to the positional arguments' add_argument().
If things do not work right away:
- Switch on argcomplete debugging with (also helpful when doing custom
completers):
export _ARC_DEBUG=1
- Run:
python-argcomplete-check-easy-install-script $(which appname)
echo $?
will echo 0 if the magic line has been found, 1 if not.
- Sometimes it helps to find early on errors using:
_ARGCOMPLETE=1 _ARC_DEBUG=1 appname
which should throw a KeyError: 'COMPLINE' (which is properly set by the
global argcomplete script).
"""
from __future__ import annotations
import argparse
from glob import glob
import os
import sys
from typing import Any
class FastFilesCompleter:
"""Fast file completer class."""
def __init__(self, directories: bool = True) -> None:
self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> list[str]:
# Only called on non option completions.
if os.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.sep)
else:
prefix_dir = 0
completion = []
globbed = []
if "*" not in prefix and "?" not in prefix:
# We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.sep:
globbed.extend(glob(prefix + ".*"))
prefix += "*"
globbed.extend(glob(prefix))
for x in sorted(globbed):
if os.path.isdir(x):
x += "/"
# Append stripping the prefix (like bash, not like compgen).
completion.append(x[prefix_dir:])
return completion
if os.environ.get("_ARGCOMPLETE"):
try:
import argcomplete.completers
except ImportError:
sys.exit(-1)
filescompleter: FastFilesCompleter | None = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False)
else:
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
pass
filescompleter = None

View File

@@ -0,0 +1,26 @@
"""Python inspection/code generation API."""
from __future__ import annotations
from .code import Code
from .code import ExceptionInfo
from .code import filter_traceback
from .code import Frame
from .code import getfslineno
from .code import Traceback
from .code import TracebackEntry
from .source import getrawcode
from .source import Source
__all__ = [
"Code",
"ExceptionInfo",
"Frame",
"Source",
"Traceback",
"TracebackEntry",
"filter_traceback",
"getfslineno",
"getrawcode",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,225 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import ast
from bisect import bisect_right
from collections.abc import Iterable
from collections.abc import Iterator
import inspect
import textwrap
import tokenize
import types
from typing import overload
import warnings
class Source:
"""An immutable object holding a source code fragment.
When using Source(...), the source lines are deindented.
"""
def __init__(self, obj: object = None) -> None:
if not obj:
self.lines: list[str] = []
self.raw_lines: list[str] = []
elif isinstance(obj, Source):
self.lines = obj.lines
self.raw_lines = obj.raw_lines
elif isinstance(obj, tuple | list):
self.lines = deindent(x.rstrip("\n") for x in obj)
self.raw_lines = list(x.rstrip("\n") for x in obj)
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
self.raw_lines = obj.split("\n")
else:
try:
rawcode = getrawcode(obj)
src = inspect.getsource(rawcode)
except TypeError:
src = inspect.getsource(obj) # type: ignore[arg-type]
self.lines = deindent(src.split("\n"))
self.raw_lines = src.split("\n")
def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
return NotImplemented
return self.lines == other.lines
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@overload
def __getitem__(self, key: int) -> str: ...
@overload
def __getitem__(self, key: slice) -> Source: ...
def __getitem__(self, key: int | slice) -> str | Source:
if isinstance(key, int):
return self.lines[key]
else:
if key.step not in (None, 1):
raise IndexError("cannot slice a Source with a step")
newsource = Source()
newsource.lines = self.lines[key.start : key.stop]
newsource.raw_lines = self.raw_lines[key.start : key.stop]
return newsource
def __iter__(self) -> Iterator[str]:
return iter(self.lines)
def __len__(self) -> int:
return len(self.lines)
def strip(self) -> Source:
"""Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self)
while start < end and not self.lines[start].strip():
start += 1
while end > start and not self.lines[end - 1].strip():
end -= 1
source = Source()
source.raw_lines = self.raw_lines
source.lines[:] = self.lines[start:end]
return source
def indent(self, indent: str = " " * 4) -> Source:
"""Return a copy of the source object with all lines indented by the
given indent-string."""
newsource = Source()
newsource.raw_lines = self.raw_lines
newsource.lines = [(indent + line) for line in self.lines]
return newsource
def getstatement(self, lineno: int) -> Source:
"""Return Source statement which contains the given linenumber
(counted from 0)."""
start, end = self.getstatementrange(lineno)
return self[start:end]
def getstatementrange(self, lineno: int) -> tuple[int, int]:
"""Return (start, end) tuple which spans the minimal statement region
which containing the given lineno."""
if not (0 <= lineno < len(self)):
raise IndexError("lineno out of range")
_ast, start, end = getstatementrange_ast(lineno, self)
return start, end
def deindent(self) -> Source:
"""Return a new Source object deindented."""
newsource = Source()
newsource.lines[:] = deindent(self.lines)
newsource.raw_lines = self.raw_lines
return newsource
def __str__(self) -> str:
return "\n".join(self.lines)
#
# helper functions
#
def findsource(obj) -> tuple[Source | None, int]:
try:
sourcelines, lineno = inspect.findsource(obj)
except Exception:
return None, -1
source = Source()
source.lines = [line.rstrip() for line in sourcelines]
source.raw_lines = sourcelines
return source, lineno
def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
"""Return code object for given function."""
try:
return obj.__code__ # type: ignore[attr-defined,no-any-return]
except AttributeError:
pass
if trycall:
call = getattr(obj, "__call__", None)
if call and not isinstance(obj, type):
return getrawcode(call, trycall=False)
raise TypeError(f"could not get code object for {obj!r}")
def deindent(lines: Iterable[str]) -> list[str]:
return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
# Flatten all statements and except handlers into one lineno-list.
# AST's line numbers start indexing at 1.
values: list[int] = []
for x in ast.walk(node):
if isinstance(x, ast.stmt | ast.ExceptHandler):
# The lineno points to the class/def, so need to include the decorators.
if isinstance(x, ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef):
for d in x.decorator_list:
values.append(d.lineno - 1)
values.append(x.lineno - 1)
for name in ("finalbody", "orelse"):
val: list[ast.stmt] | None = getattr(x, name, None)
if val:
# Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1)
values.sort()
insert_index = bisect_right(values, lineno)
start = values[insert_index - 1]
if insert_index >= len(values):
end = None
else:
end = values[insert_index]
return start, end
def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: ast.AST | None = None,
) -> tuple[ast.AST, int, int]:
if astnode is None:
content = str(source)
# See #4260:
# Don't produce duplicate warnings when compiling source to find AST.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = ast.parse(content, "source", "exec")
start, end = get_statement_startend2(lineno, astnode)
# We need to correct the end:
# - ast-parsing strips comments
# - there might be empty lines
# - we might have lesser indented code blocks at the end
if end is None:
end = len(source.lines)
if end > start + 1:
# Make sure we don't span differently indented code blocks
# by using the BlockFinder helper used which inspect.getsource() uses itself.
block_finder = inspect.BlockFinder()
# If we start with an indented line, put blockfinder to "started" mode.
block_finder.started = (
bool(source.lines[start]) and source.lines[start][0].isspace()
)
it = ((x + "\n") for x in source.lines[start:end])
try:
for tok in tokenize.generate_tokens(lambda: next(it)):
block_finder.tokeneater(*tok)
except (inspect.EndOfBlock, IndentationError):
end = block_finder.last + start
except Exception:
pass
# The end might still point to a comment or empty line, correct it.
while end:
line = source.lines[end - 1].lstrip()
if line.startswith("#") or not line:
end -= 1
else:
break
return astnode, start, end

View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter
__all__ = [
"TerminalWriter",
"get_terminal_width",
]

View File

@@ -0,0 +1,673 @@
# mypy: allow-untyped-defs
# This module was imported from the cpython standard library
# (https://github.com/python/cpython/) at commit
# c5140945c723ae6c4b7ee81ff720ac8ea4b52cfd (python3.12).
#
#
# Original Author: Fred L. Drake, Jr.
# fdrake@acm.org
#
# This is a simple little module I wrote to make life easier. I didn't
# see anything quite like it in the library, though I may have overlooked
# something. I wrote this when I was trying to read some heavily nested
# tuples with fairly non-descriptive content. This is modeled very much
# after Lisp/Scheme - style pretty-printing of lists. If you find it
# useful, thank small children who sleep at night.
from __future__ import annotations
import collections as _collections
from collections.abc import Callable
from collections.abc import Iterator
import dataclasses as _dataclasses
from io import StringIO as _StringIO
import re
import types as _types
from typing import Any
from typing import IO
class _safe_key:
"""Helper function for key functions when sorting unorderable objects.
The wrapped-object will fallback to a Py2.x style comparison for
unorderable types (sorting first comparing the type name and then by
the obj ids). Does not work recursively, so dict.items() must have
_safe_key applied to both the key and the value.
"""
__slots__ = ["obj"]
def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
try:
return self.obj < other.obj
except TypeError:
return (str(type(self.obj)), id(self.obj)) < (
str(type(other.obj)),
id(other.obj),
)
def _safe_tuple(t):
"""Helper function for comparing 2-tuples"""
return _safe_key(t[0]), _safe_key(t[1])
class PrettyPrinter:
def __init__(
self,
indent: int = 4,
width: int = 80,
depth: int | None = None,
) -> None:
"""Handle pretty printing operations onto a stream using a set of
configured parameters.
indent
Number of spaces to indent for each level of nesting.
width
Attempted maximum number of columns in the output.
depth
The maximum depth to print out nested structures.
"""
if indent < 0:
raise ValueError("indent must be >= 0")
if depth is not None and depth <= 0:
raise ValueError("depth must be > 0")
if not width:
raise ValueError("width must be != 0")
self._depth = depth
self._indent_per_level = indent
self._width = width
def pformat(self, object: Any) -> str:
sio = _StringIO()
self._format(object, sio, 0, 0, set(), 0)
return sio.getvalue()
def _format(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
objid = id(object)
if objid in context:
stream.write(_recursion(object))
return
p = self._dispatch.get(type(object).__repr__, None)
if p is not None:
context.add(objid)
p(self, object, stream, indent, allowance, context, level + 1)
context.remove(objid)
elif (
_dataclasses.is_dataclass(object)
and not isinstance(object, type)
and object.__dataclass_params__.repr # type:ignore[attr-defined]
and
# Check dataclass has generated repr method.
hasattr(object.__repr__, "__wrapped__")
and "__create_fn__" in object.__repr__.__wrapped__.__qualname__
):
context.add(objid)
self._pprint_dataclass(
object, stream, indent, allowance, context, level + 1
)
context.remove(objid)
else:
stream.write(self._repr(object, context, level))
def _pprint_dataclass(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
cls_name = object.__class__.__name__
items = [
(f.name, getattr(object, f.name))
for f in _dataclasses.fields(object)
if f.repr
]
stream.write(cls_name + "(")
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
_dispatch: dict[
Callable[..., str],
Callable[[PrettyPrinter, Any, IO[str], int, int, set[int], int], None],
] = {}
def _pprint_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
write("{")
items = sorted(object.items(), key=_safe_tuple)
self._format_dict_items(items, stream, indent, allowance, context, level)
write("}")
_dispatch[dict.__repr__] = _pprint_dict
def _pprint_ordered_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
cls = object.__class__
stream.write(cls.__name__ + "(")
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
def _pprint_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("[")
self._format_items(object, stream, indent, allowance, context, level)
stream.write("]")
_dispatch[list.__repr__] = _pprint_list
def _pprint_tuple(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("(")
self._format_items(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[tuple.__repr__] = _pprint_tuple
def _pprint_set(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
typ = object.__class__
if typ is set:
stream.write("{")
endchar = "}"
else:
stream.write(typ.__name__ + "({")
endchar = "})"
object = sorted(object, key=_safe_key)
self._format_items(object, stream, indent, allowance, context, level)
stream.write(endchar)
_dispatch[set.__repr__] = _pprint_set
_dispatch[frozenset.__repr__] = _pprint_set
def _pprint_str(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
if not len(object):
write(repr(object))
return
chunks = []
lines = object.splitlines(True)
if level == 1:
indent += 1
allowance += 1
max_width1 = max_width = self._width - indent
for i, line in enumerate(lines):
rep = repr(line)
if i == len(lines) - 1:
max_width1 -= allowance
if len(rep) <= max_width1:
chunks.append(rep)
else:
# A list of alternating (non-space, space) strings
parts = re.findall(r"\S*\s*", line)
assert parts
assert not parts[-1]
parts.pop() # drop empty last part
max_width2 = max_width
current = ""
for j, part in enumerate(parts):
candidate = current + part
if j == len(parts) - 1 and i == len(lines) - 1:
max_width2 -= allowance
if len(repr(candidate)) > max_width2:
if current:
chunks.append(repr(current))
current = part
else:
current = candidate
if current:
chunks.append(repr(current))
if len(chunks) == 1:
write(rep)
return
if level == 1:
write("(")
for i, rep in enumerate(chunks):
if i > 0:
write("\n" + " " * indent)
write(rep)
if level == 1:
write(")")
_dispatch[str.__repr__] = _pprint_str
def _pprint_bytes(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
if len(object) <= 4:
write(repr(object))
return
parens = level == 1
if parens:
indent += 1
allowance += 1
write("(")
delim = ""
for rep in _wrap_bytes_repr(object, self._width - indent, allowance):
write(delim)
write(rep)
if not delim:
delim = "\n" + " " * indent
if parens:
write(")")
_dispatch[bytes.__repr__] = _pprint_bytes
def _pprint_bytearray(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
write("bytearray(")
self._pprint_bytes(
bytes(object), stream, indent + 10, allowance + 1, context, level + 1
)
write(")")
_dispatch[bytearray.__repr__] = _pprint_bytearray
def _pprint_mappingproxy(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("mappingproxy(")
self._format(object.copy(), stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy
def _pprint_simplenamespace(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if type(object) is _types.SimpleNamespace:
# The SimpleNamespace repr is "namespace" instead of the class
# name, so we do the same here. For subclasses; use the class name.
cls_name = "namespace"
else:
cls_name = object.__class__.__name__
items = object.__dict__.items()
stream.write(cls_name + "(")
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace
def _format_dict_items(
self,
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for key, ent in items:
write(delimnl)
write(self._repr(key, context, level))
write(": ")
self._format(ent, stream, item_indent, 1, context, level)
write(",")
write("\n" + " " * indent)
def _format_namespace_items(
self,
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for key, ent in items:
write(delimnl)
write(key)
write("=")
if id(ent) in context:
# Special-case representation of recursion to match standard
# recursive dataclass repr.
write("...")
else:
self._format(
ent,
stream,
item_indent + len(key) + 1,
1,
context,
level,
)
write(",")
write("\n" + " " * indent)
def _format_items(
self,
items: list[Any],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for item in items:
write(delimnl)
self._format(item, stream, item_indent, 1, context, level)
write(",")
write("\n" + " " * indent)
def _repr(self, object: Any, context: set[int], level: int) -> str:
return self._safe_repr(object, context.copy(), self._depth, level)
def _pprint_default_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
rdf = self._repr(object.default_factory, context, level)
stream.write(f"{object.__class__.__name__}({rdf}, ")
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict
def _pprint_counter(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object:
stream.write("{")
items = object.most_common()
self._format_dict_items(items, stream, indent, allowance, context, level)
stream.write("}")
stream.write(")")
_dispatch[_collections.Counter.__repr__] = _pprint_counter
def _pprint_chain_map(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
stream.write(repr(object))
return
stream.write(object.__class__.__name__ + "(")
self._format_items(object.maps, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map
def _pprint_deque(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object.maxlen is not None:
stream.write(f"maxlen={object.maxlen}, ")
stream.write("[")
self._format_items(object, stream, indent, allowance + 1, context, level)
stream.write("])")
_dispatch[_collections.deque.__repr__] = _pprint_deque
def _pprint_user_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserDict.__repr__] = _pprint_user_dict
def _pprint_user_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserList.__repr__] = _pprint_user_list
def _pprint_user_string(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr(
self, object: Any, context: set[int], maxlevels: int | None, level: int
) -> str:
typ = type(object)
if typ in _builtin_scalars:
return repr(object)
r = getattr(typ, "__repr__", None)
if issubclass(typ, dict) and r is dict.__repr__:
if not object:
return "{}"
objid = id(object)
if maxlevels and level >= maxlevels:
return "{...}"
if objid in context:
return _recursion(object)
context.add(objid)
components: list[str] = []
append = components.append
level += 1
for k, v in sorted(object.items(), key=_safe_tuple):
krepr = self._safe_repr(k, context, maxlevels, level)
vrepr = self._safe_repr(v, context, maxlevels, level)
append(f"{krepr}: {vrepr}")
context.remove(objid)
return "{{{}}}".format(", ".join(components))
if (issubclass(typ, list) and r is list.__repr__) or (
issubclass(typ, tuple) and r is tuple.__repr__
):
if issubclass(typ, list):
if not object:
return "[]"
format = "[%s]"
elif len(object) == 1:
format = "(%s,)"
else:
if not object:
return "()"
format = "(%s)"
objid = id(object)
if maxlevels and level >= maxlevels:
return format % "..."
if objid in context:
return _recursion(object)
context.add(objid)
components = []
append = components.append
level += 1
for o in object:
orepr = self._safe_repr(o, context, maxlevels, level)
append(orepr)
context.remove(objid)
return format % ", ".join(components)
return repr(object)
_builtin_scalars = frozenset(
{str, bytes, bytearray, float, complex, bool, type(None), int}
)
def _recursion(object: Any) -> str:
return f"<Recursion on {type(object).__name__} with id={id(object)}>"
def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
current = b""
last = len(object) // 4 * 4
for i in range(0, len(object), 4):
part = object[i : i + 4]
candidate = current + part
if i == last:
width -= allowance
if len(repr(candidate)) > width:
if current:
yield repr(current)
current = part
else:
current = candidate
if current:
yield repr(current)

View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import pprint
import reprlib
def _try_repr_or_str(obj: object) -> str:
try:
return repr(obj)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException:
return f'{type(obj).__name__}("{obj}")'
def _format_repr_exception(exc: BaseException, obj: object) -> str:
try:
exc_info = _try_repr_or_str(exc)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as inner_exc:
exc_info = f"unpresentable exception ({_try_repr_or_str(inner_exc)})"
return (
f"<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>"
)
def _ellipsize(s: str, maxsize: int) -> str:
if len(s) > maxsize:
i = max(0, (maxsize - 3) // 2)
j = max(0, maxsize - 3 - i)
return s[:i] + "..." + s[len(s) - j :]
return s
class SafeRepr(reprlib.Repr):
"""
repr.Repr that limits the resulting size of repr() and includes
information on exceptions raised during the call.
"""
def __init__(self, maxsize: int | None, use_ascii: bool = False) -> None:
"""
:param maxsize:
If not None, will truncate the resulting repr to that specific size, using ellipsis
somewhere in the middle to hide the extra text.
If None, will not impose any size limits on the returning repr.
"""
super().__init__()
# ``maxstring`` is used by the superclass, and needs to be an int; using a
# very large number in case maxsize is None, meaning we want to disable
# truncation.
self.maxstring = maxsize if maxsize is not None else 1_000_000_000
self.maxsize = maxsize
self.use_ascii = use_ascii
def repr(self, x: object) -> str:
try:
if self.use_ascii:
s = ascii(x)
else:
s = super().repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def repr_instance(self, x: object, level: int) -> str:
try:
s = repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def safeformat(obj: object) -> str:
"""Return a pretty printed string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info.
"""
try:
return pprint.pformat(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)
# Maximum size of overall repr of objects to display during assertion errors.
DEFAULT_REPR_MAX_SIZE = 240
def saferepr(
obj: object, maxsize: int | None = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False
) -> str:
"""Return a size-limited safe repr-string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info and 'saferepr' generally takes
care to never raise exceptions itself.
This function is a wrapper around the Repr/reprlib functionality of the
stdlib.
"""
return SafeRepr(maxsize, use_ascii).repr(obj)
def saferepr_unlimited(obj: object, use_ascii: bool = True) -> str:
"""Return an unlimited-size safe repr-string for the given object.
As with saferepr, failing __repr__ functions of user instances
will be represented with a short exception info.
This function is a wrapper around simple repr.
Note: a cleaner solution would be to alter ``saferepr``this way
when maxsize=None, but that might affect some other code.
"""
try:
if use_ascii:
return ascii(obj)
return repr(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)

View File

@@ -0,0 +1,258 @@
"""Helper functions for writing to terminals and files."""
from __future__ import annotations
from collections.abc import Sequence
import os
import shutil
import sys
from typing import final
from typing import Literal
from typing import TextIO
import pygments
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexer import Lexer
from pygments.lexers.diff import DiffLexer
from pygments.lexers.python import PythonLexer
from ..compat import assert_never
from .wcwidth import wcswidth
# This code was initially copied from py 1.8.1, file _io/terminalwriter.py.
def get_terminal_width() -> int:
width, _ = shutil.get_terminal_size(fallback=(80, 24))
# The Windows get_terminal_size may be bogus, let's sanify a bit.
if width < 40:
width = 80
return width
def should_do_markup(file: TextIO) -> bool:
if os.environ.get("PY_COLORS") == "1":
return True
if os.environ.get("PY_COLORS") == "0":
return False
if os.environ.get("NO_COLOR"):
return False
if os.environ.get("FORCE_COLOR"):
return True
return (
hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb"
)
@final
class TerminalWriter:
_esctable = dict(
black=30,
red=31,
green=32,
yellow=33,
blue=34,
purple=35,
cyan=36,
white=37,
Black=40,
Red=41,
Green=42,
Yellow=43,
Blue=44,
Purple=45,
Cyan=46,
White=47,
bold=1,
light=2,
blink=5,
invert=7,
)
def __init__(self, file: TextIO | None = None) -> None:
if file is None:
file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
try:
import colorama
except ImportError:
pass
else:
file = colorama.AnsiToWin32(file).stream
assert file is not None
self._file = file
self.hasmarkup = should_do_markup(file)
self._current_line = ""
self._terminal_width: int | None = None
self.code_highlight = True
@property
def fullwidth(self) -> int:
if self._terminal_width is not None:
return self._terminal_width
return get_terminal_width()
@fullwidth.setter
def fullwidth(self, value: int) -> None:
self._terminal_width = value
@property
def width_of_current_line(self) -> int:
"""Return an estimate of the width so far in the current line."""
return wcswidth(self._current_line)
def markup(self, text: str, **markup: bool) -> str:
for name in markup:
if name not in self._esctable:
raise ValueError(f"unknown markup: {name!r}")
if self.hasmarkup:
esc = [self._esctable[name] for name, on in markup.items() if on]
if esc:
text = "".join(f"\x1b[{cod}m" for cod in esc) + text + "\x1b[0m"
return text
def sep(
self,
sepchar: str,
title: str | None = None,
fullwidth: int | None = None,
**markup: bool,
) -> None:
if fullwidth is None:
fullwidth = self.fullwidth
# The goal is to have the line be as long as possible
# under the condition that len(line) <= fullwidth.
if sys.platform == "win32":
# If we print in the last column on windows we are on a
# new line but there is no way to verify/neutralize this
# (we may not know the exact line width).
# So let's be defensive to avoid empty lines in the output.
fullwidth -= 1
if title is not None:
# we want 2 + 2*len(fill) + len(title) <= fullwidth
# i.e. 2 + 2*len(sepchar)*N + len(title) <= fullwidth
# 2*len(sepchar)*N <= fullwidth - len(title) - 2
# N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
fill = sepchar * N
line = f"{fill} {title} {fill}"
else:
# we want len(sepchar)*N <= fullwidth
# i.e. N <= fullwidth // len(sepchar)
line = sepchar * (fullwidth // len(sepchar))
# In some situations there is room for an extra sepchar at the right,
# in particular if we consider that with a sepchar like "_ " the
# trailing space is not important at the end of the line.
if len(line) + len(sepchar.rstrip()) <= fullwidth:
line += sepchar.rstrip()
self.line(line, **markup)
def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
if msg:
current_line = msg.rsplit("\n", 1)[-1]
if "\n" in msg:
self._current_line = current_line
else:
self._current_line += current_line
msg = self.markup(msg, **markup)
self.write_raw(msg, flush=flush)
def write_raw(self, msg: str, *, flush: bool = False) -> None:
try:
self._file.write(msg)
except UnicodeEncodeError:
# Some environments don't support printing general Unicode
# strings, due to misconfiguration or otherwise; in that case,
# print the string escaped to ASCII.
# When the Unicode situation improves we should consider
# letting the error propagate instead of masking it (see #7475
# for one brief attempt).
msg = msg.encode("unicode-escape").decode("ascii")
self._file.write(msg)
if flush:
self.flush()
def line(self, s: str = "", **markup: bool) -> None:
self.write(s, **markup)
self.write("\n")
def flush(self) -> None:
self._file.flush()
def _write_source(self, lines: Sequence[str], indents: Sequence[str] = ()) -> None:
"""Write lines of source code possibly highlighted.
Keeping this private for now because the API is clunky. We should discuss how
to evolve the terminal writer so we can have more precise color support, for example
being able to write part of a line in one color and the rest in another, and so on.
"""
if indents and len(indents) != len(lines):
raise ValueError(
f"indents size ({len(indents)}) should have same size as lines ({len(lines)})"
)
if not indents:
indents = [""] * len(lines)
source = "\n".join(lines)
new_lines = self._highlight(source).splitlines()
# Would be better to strict=True but that fails some CI jobs.
for indent, new_line in zip(indents, new_lines, strict=False):
self.line(indent + new_line)
def _get_pygments_lexer(self, lexer: Literal["python", "diff"]) -> Lexer:
if lexer == "python":
return PythonLexer()
elif lexer == "diff":
return DiffLexer()
else:
assert_never(lexer)
def _get_pygments_formatter(self) -> TerminalFormatter:
from _pytest.config.exceptions import UsageError
theme = os.getenv("PYTEST_THEME")
theme_mode = os.getenv("PYTEST_THEME_MODE", "dark")
try:
return TerminalFormatter(bg=theme_mode, style=theme)
except pygments.util.ClassNotFound as e:
raise UsageError(
f"PYTEST_THEME environment variable has an invalid value: '{theme}'. "
"Hint: See available pygments styles with `pygmentize -L styles`."
) from e
except pygments.util.OptionError as e:
raise UsageError(
f"PYTEST_THEME_MODE environment variable has an invalid value: '{theme_mode}'. "
"The allowed values are 'dark' (default) and 'light'."
) from e
def _highlight(
self, source: str, lexer: Literal["diff", "python"] = "python"
) -> str:
"""Highlight the given source if we have markup support."""
if not source or not self.hasmarkup or not self.code_highlight:
return source
pygments_lexer = self._get_pygments_lexer(lexer)
pygments_formatter = self._get_pygments_formatter()
highlighted: str = pygments.highlight(
source, pygments_lexer, pygments_formatter
)
# pygments terminal formatter may add a newline when there wasn't one.
# We don't want this, remove.
if highlighted[-1] == "\n" and source[-1] != "\n":
highlighted = highlighted[:-1]
# Some lexers will not set the initial color explicitly
# which may lead to the previous color being propagated to the
# start of the expression, so reset first.
highlighted = "\x1b[0m" + highlighted
return highlighted

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from functools import lru_cache
import unicodedata
@lru_cache(100)
def wcwidth(c: str) -> int:
"""Determine how many columns are needed to display a character in a terminal.
Returns -1 if the character is not printable.
Returns 0, 1 or 2 for other characters.
"""
o = ord(c)
# ASCII fast path.
if 0x20 <= o < 0x07F:
return 1
# Some Cf/Zp/Zl characters which should be zero-width.
if (
o == 0x0000
or 0x200B <= o <= 0x200F
or 0x2028 <= o <= 0x202E
or 0x2060 <= o <= 0x2063
):
return 0
category = unicodedata.category(c)
# Control characters.
if category == "Cc":
return -1
# Combining characters with zero width.
if category in ("Me", "Mn"):
return 0
# Full/Wide east asian characters.
if unicodedata.east_asian_width(c) in ("F", "W"):
return 2
return 1
def wcswidth(s: str) -> int:
"""Determine how many columns are needed to display a string in a terminal.
Returns -1 if the string contains non-printable characters.
"""
width = 0
for c in unicodedata.normalize("NFC", s):
wc = wcwidth(c)
if wc < 0:
return -1
width += wc
return width

View File

@@ -0,0 +1,119 @@
"""create errno-specific classes for IO or os calls."""
from __future__ import annotations
from collections.abc import Callable
import errno
import os
import sys
from typing import TYPE_CHECKING
from typing import TypeVar
if TYPE_CHECKING:
from typing_extensions import ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
class Error(EnvironmentError):
def __repr__(self) -> str:
return "{}.{} {!r}: {} ".format(
self.__class__.__module__,
self.__class__.__name__,
self.__class__.__doc__,
" ".join(map(str, self.args)),
# repr(self.args)
)
def __str__(self) -> str:
s = "[{}]: {}".format(
self.__class__.__doc__,
" ".join(map(str, self.args)),
)
return s
_winerrnomap = {
2: errno.ENOENT,
3: errno.ENOENT,
17: errno.EEXIST,
18: errno.EXDEV,
13: errno.EBUSY, # empty cd drive, but ENOMEDIUM seems unavailable
22: errno.ENOTDIR,
20: errno.ENOTDIR,
267: errno.ENOTDIR,
5: errno.EACCES, # anything better?
}
class ErrorMaker:
"""lazily provides Exception classes for each possible POSIX errno
(as defined per the 'errno' module). All such instances
subclass EnvironmentError.
"""
_errno2class: dict[int, type[Error]] = {}
def __getattr__(self, name: str) -> type[Error]:
if name[0] == "_":
raise AttributeError(name)
eno = getattr(errno, name)
cls = self._geterrnoclass(eno)
setattr(self, name, cls)
return cls
def _geterrnoclass(self, eno: int) -> type[Error]:
try:
return self._errno2class[eno]
except KeyError:
clsname = errno.errorcode.get(eno, f"UnknownErrno{eno}")
errorcls = type(
clsname,
(Error,),
{"__module__": "py.error", "__doc__": os.strerror(eno)},
)
self._errno2class[eno] = errorcls
return errorcls
def checked_call(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
"""Call a function and raise an errno-exception if applicable."""
__tracebackhide__ = True
try:
return func(*args, **kwargs)
except Error:
raise
except OSError as value:
if not hasattr(value, "errno"):
raise
if sys.platform == "win32":
try:
# error: Invalid index type "Optional[int]" for "dict[int, int]"; expected type "int" [index]
# OK to ignore because we catch the KeyError below.
cls = self._geterrnoclass(_winerrnomap[value.errno]) # type:ignore[index]
except KeyError:
raise value
else:
# we are not on Windows, or we got a proper OSError
if value.errno is None:
cls = type(
"UnknownErrnoNone",
(Error,),
{"__module__": "py.error", "__doc__": None},
)
else:
cls = self._geterrnoclass(value.errno)
raise cls(f"{func.__name__}{args!r}")
_error_maker = ErrorMaker()
checked_call = _error_maker.checked_call
def __getattr__(attr: str) -> type[Error]:
return getattr(_error_maker, attr) # type: ignore[no-any-return]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '9.0.1'
__version_tuple__ = version_tuple = (9, 0, 1)
__commit_id__ = commit_id = None

View File

@@ -0,0 +1,208 @@
# mypy: allow-untyped-defs
"""Support for presenting detailed information in failing assertions."""
from __future__ import annotations
from collections.abc import Generator
import sys
from typing import Any
from typing import Protocol
from typing import TYPE_CHECKING
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
from _pytest.assertion.rewrite import assertstate_key
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
if TYPE_CHECKING:
from _pytest.main import Session
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--assert",
action="store",
dest="assertmode",
choices=("rewrite", "plain"),
default="rewrite",
metavar="MODE",
help=(
"Control assertion debugging tools.\n"
"'plain' performs no assertion debugging.\n"
"'rewrite' (the default) rewrites assert statements in test modules"
" on import to provide assert expression information."
),
)
parser.addini(
"enable_assertion_pass_hook",
type="bool",
default=False,
help="Enables the pytest_assertion_pass hook. "
"Make sure to delete any previously generated pyc cache files.",
)
parser.addini(
"truncation_limit_lines",
default=None,
help="Set threshold of LINES after which truncation will take effect",
)
parser.addini(
"truncation_limit_chars",
default=None,
help=("Set threshold of CHARS after which truncation will take effect"),
)
Config._add_verbosity_ini(
parser,
Config.VERBOSITY_ASSERTIONS,
help=(
"Specify a verbosity level for assertions, overriding the main level. "
"Higher levels will provide more detailed explanation when an assertion fails."
),
)
def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
the package will get their assert statements rewritten.
Thus you should make sure to call this before the module is
actually imported, usually in your __init__.py if you are a plugin
using a package.
:param names: The module names to register.
"""
for name in names:
if not isinstance(name, str):
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
raise TypeError(msg.format(repr(names)))
rewrite_hook: RewriteHook
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
rewrite_hook = hook
break
else:
rewrite_hook = DummyRewriteHook()
rewrite_hook.mark_rewrite(*names)
class RewriteHook(Protocol):
def mark_rewrite(self, *names: str) -> None: ...
class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names: str) -> None:
pass
class AssertionState:
"""State for the assertion plugin."""
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook: rewrite.AssertionRewritingHook | None = None
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config.stash[assertstate_key] = AssertionState(config, "rewrite")
config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config.stash[assertstate_key].trace("installed rewrite import hook")
def undo() -> None:
hook = config.stash[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)
config.add_cleanup(undo)
return hook
def pytest_collection(session: Session) -> None:
# This hook is only called when test modules are collected
# so for example not in the managing process of pytest-xdist
# (which does not collect test modules).
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(session)
@hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
The rewrite module will use util._reprcompare if it exists to use custom
reporting via the pytest_assertrepr_compare hook. This sets up this custom
comparison for the test.
"""
ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> str | None:
"""Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the
following:
* Overly verbose explanations are truncated unless configured otherwise
(eg. if running in verbose mode).
* Embedded newlines are escaped to help util.format_explanation()
later.
* If the rewrite mode is used embedded %-characters are replaced
to protect later % formatting.
The result can be formatted by util.format_explanation() for
pretty printing.
"""
hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right
)
for new_expl in hook_result:
if new_expl:
new_expl = truncate.truncate_if_required(new_expl, item)
new_expl = [line.replace("\n", "\\n") for line in new_expl]
res = "\n~".join(new_expl)
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
return res
return None
saved_assert_hooks = util._reprcompare, util._assertion_pass
util._reprcompare = callbinrepr
util._config = item.config
if ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
try:
return (yield)
finally:
util._reprcompare, util._assertion_pass = saved_assert_hooks
util._config = None
def pytest_sessionfinish(session: Session) -> None:
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(None)
def pytest_assertrepr_compare(
config: Config, op: str, left: Any, right: Any
) -> list[str] | None:
return util.assertrepr_compare(config=config, op=op, left=left, right=right)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
"""Utilities for truncating assertion output.
Current default behaviour is to truncate assertion explanations at
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
"""
from __future__ import annotations
from _pytest.compat import running_on_ci
from _pytest.config import Config
from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = DEFAULT_MAX_LINES * 80
USAGE_MSG = "use '-vv' to show"
def truncate_if_required(explanation: list[str], item: Item) -> list[str]:
"""Truncate this assertion explanation if the given test item is eligible."""
should_truncate, max_lines, max_chars = _get_truncation_parameters(item)
if should_truncate:
return _truncate_explanation(
explanation,
max_lines=max_lines,
max_chars=max_chars,
)
return explanation
def _get_truncation_parameters(item: Item) -> tuple[bool, int, int]:
"""Return the truncation parameters related to the given item, as (should truncate, max lines, max chars)."""
# We do not need to truncate if one of conditions is met:
# 1. Verbosity level is 2 or more;
# 2. Test is being run in CI environment;
# 3. Both truncation_limit_lines and truncation_limit_chars
# .ini parameters are set to 0 explicitly.
max_lines = item.config.getini("truncation_limit_lines")
max_lines = int(max_lines if max_lines is not None else DEFAULT_MAX_LINES)
max_chars = item.config.getini("truncation_limit_chars")
max_chars = int(max_chars if max_chars is not None else DEFAULT_MAX_CHARS)
verbose = item.config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
should_truncate = verbose < 2 and not running_on_ci()
should_truncate = should_truncate and (max_lines > 0 or max_chars > 0)
return should_truncate, max_lines, max_chars
def _truncate_explanation(
input_lines: list[str],
max_lines: int,
max_chars: int,
) -> list[str]:
"""Truncate given list of strings that makes up the assertion explanation.
Truncates to either max_lines, or max_chars - whichever the input reaches
first, taking the truncation explanation into account. The remaining lines
will be replaced by a usage message.
"""
# Check if truncation required
input_char_count = len("".join(input_lines))
# The length of the truncation explanation depends on the number of lines
# removed but is at least 68 characters:
# The real value is
# 64 (for the base message:
# '...\n...Full output truncated (1 line hidden), use '-vv' to show")'
# )
# + 1 (for plural)
# + int(math.log10(len(input_lines) - max_lines)) (number of hidden line, at least 1)
# + 3 for the '...' added to the truncated line
# But if there's more than 100 lines it's very likely that we're going to
# truncate, so we don't need the exact value using log10.
tolerable_max_chars = (
max_chars + 70 # 64 + 1 (for plural) + 2 (for '99') + 3 for '...'
)
# The truncation explanation add two lines to the output
tolerable_max_lines = max_lines + 2
if (
len(input_lines) <= tolerable_max_lines
and input_char_count <= tolerable_max_chars
):
return input_lines
# Truncate first to max_lines, and then truncate to max_chars if necessary
if max_lines > 0:
truncated_explanation = input_lines[:max_lines]
else:
truncated_explanation = input_lines
truncated_char = True
# We reevaluate the need to truncate chars following removal of some lines
if len("".join(truncated_explanation)) > tolerable_max_chars and max_chars > 0:
truncated_explanation = _truncate_by_char_count(
truncated_explanation, max_chars
)
else:
truncated_char = False
if truncated_explanation == input_lines:
# No truncation happened, so we do not need to add any explanations
return truncated_explanation
truncated_line_count = len(input_lines) - len(truncated_explanation)
if truncated_explanation[-1]:
# Add ellipsis and take into account part-truncated final line
truncated_explanation[-1] = truncated_explanation[-1] + "..."
if truncated_char:
# It's possible that we did not remove any char from this line
truncated_line_count += 1
else:
# Add proper ellipsis when we were able to fit a full line exactly
truncated_explanation[-1] = "..."
return [
*truncated_explanation,
"",
f"...Full output truncated ({truncated_line_count} line"
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
]
def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
# Find point at which input length exceeds total allowed length
iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines):
if iterated_char_count + len(input_line) > max_chars:
break
iterated_char_count += len(input_line)
# Create truncated explanation with modified final line
truncated_result = input_lines[:iterated_index]
final_line = input_lines[iterated_index]
if final_line:
final_line_truncate_point = max_chars - iterated_char_count
final_line = final_line[:final_line_truncate_point]
truncated_result.append(final_line)
return truncated_result

View File

@@ -0,0 +1,615 @@
# mypy: allow-untyped-defs
"""Utilities for assertion debugging."""
from __future__ import annotations
import collections.abc
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
from collections.abc import Set as AbstractSet
import pprint
from typing import Any
from typing import Literal
from typing import Protocol
from unicodedata import normalize
from _pytest import outcomes
import _pytest._code
from _pytest._io.pprint import PrettyPrinter
from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited
from _pytest.compat import running_on_ci
from _pytest.config import Config
# The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
_reprcompare: Callable[[str, object, object], str | None] | None = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass: Callable[[int, str, str], None] | None = None
# Config object which is assigned during pytest_runtest_protocol.
_config: Config | None = None
class _HighlightFunc(Protocol):
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Apply highlighting to the given source."""
def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Dummy highlighter that returns the text unprocessed.
Needed for _notin_text, as the diff gets post-processed to only show the "+" part.
"""
return source
def format_explanation(explanation: str) -> str:
r"""Format an explanation.
Normally all embedded newlines are escaped, however there are
three exceptions: \n{, \n} and \n~. The first two are intended
cover nested explanations, see function and attribute explanations
for examples (.visit_Call(), visit_Attribute()). The last one is
for when one explanation needs to span multiple lines, e.g. when
displaying diffs.
"""
lines = _split_explanation(explanation)
result = _format_lines(lines)
return "\n".join(result)
def _split_explanation(explanation: str) -> list[str]:
r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the
literal '\n' characters.
"""
raw_lines = (explanation or "").split("\n")
lines = [raw_lines[0]]
for values in raw_lines[1:]:
if values and values[0] in ["{", "}", "~", ">"]:
lines.append(values)
else:
lines[-1] += "\\n" + values
return lines
def _format_lines(lines: Sequence[str]) -> list[str]:
"""Format the individual lines.
This will replace the '{', '}' and '~' characters of our mini formatting
language with the proper 'where ...', 'and ...' and ' + ...' text, taking
care of indentation along the way.
Return a list of formatted lines.
"""
result = list(lines[:1])
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith("{"):
if stackcnt[-1]:
s = "and "
else:
s = "where "
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
elif line.startswith("}"):
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
else:
assert line[0] in ["~", ">"]
stack[-1] += 1
indent = len(stack) if line.startswith("~") else len(stack) - 1
result.append(" " * indent + line[1:])
assert len(stack) == 1
return result
def issequence(x: Any) -> bool:
return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
def istext(x: Any) -> bool:
return isinstance(x, str)
def isdict(x: Any) -> bool:
return isinstance(x, dict)
def isset(x: Any) -> bool:
return isinstance(x, set | frozenset)
def isnamedtuple(obj: Any) -> bool:
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None
def isattrs(obj: Any) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None
def isiterable(obj: Any) -> bool:
try:
iter(obj)
return not istext(obj)
except Exception:
return False
def has_default_eq(
obj: object,
) -> bool:
"""Check if an instance of an object contains the default eq
First, we check if the object's __eq__ attribute has __code__,
if so, we check the equally of the method code filename (__code__.co_filename)
to the default one generated by the dataclass and attr module
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
"""
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"):
code_filename = obj.__eq__.__code__.co_filename
if isattrs(obj):
return "attrs generated " in code_filename
return code_filename == "<string>" # data class
return True
def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False
) -> list[str] | None:
"""Return specialised explanations for some operators/operands."""
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
# See issue #3246.
use_ascii = (
isinstance(left, str)
and isinstance(right, str)
and normalize("NFD", left) == normalize("NFD", right)
)
if verbose > 1:
left_repr = saferepr_unlimited(left, use_ascii=use_ascii)
right_repr = saferepr_unlimited(right, use_ascii=use_ascii)
else:
# XXX: "15 chars indentation" is wrong
# ("E AssertionError: assert "); should use term width.
maxsize = (
80 - 15 - len(op) - 2
) // 2 # 15 chars indentation, 1 space around op
left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii)
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
summary = f"{left_repr} {op} {right_repr}"
highlighter = config.get_terminal_writer()._highlight
explanation = None
try:
if op == "==":
explanation = _compare_eq_any(left, right, highlighter, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
elif op == "!=":
if isset(left) and isset(right):
explanation = ["Both sets are equal"]
elif op == ">=":
if isset(left) and isset(right):
explanation = _compare_gte_set(left, right, highlighter, verbose)
elif op == "<=":
if isset(left) and isset(right):
explanation = _compare_lte_set(left, right, highlighter, verbose)
elif op == ">":
if isset(left) and isset(right):
explanation = _compare_gt_set(left, right, highlighter, verbose)
elif op == "<":
if isset(left) and isset(right):
explanation = _compare_lt_set(left, right, highlighter, verbose)
except outcomes.Exit:
raise
except Exception:
repr_crash = _pytest._code.ExceptionInfo.from_current()._getreprcrash()
explanation = [
f"(pytest_assertion plugin: representation of details failed: {repr_crash}.",
" Probably an object has a faulty __repr__.)",
]
if not explanation:
return None
if explanation[0] != "":
explanation = ["", *explanation]
return [summary, *explanation]
def _compare_eq_any(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
) -> list[str]:
explanation = []
if istext(left) and istext(right):
explanation = _diff_text(left, right, highlighter, verbose)
else:
from _pytest.python_api import ApproxBase
if isinstance(left, ApproxBase) or isinstance(right, ApproxBase):
# Although the common order should be obtained == expected, this ensures both ways
approx_side = left if isinstance(left, ApproxBase) else right
other_side = right if isinstance(left, ApproxBase) else left
explanation = approx_side._repr_compare(other_side)
elif type(left) is type(right) and (
isdatacls(left) or isattrs(left) or isnamedtuple(left)
):
# Note: unlike dataclasses/attrs, namedtuples compare only the
# field values, not the type or field names. But this branch
# intentionally only handles the same-type case, which was often
# used in older code bases before dataclasses/attrs were available.
explanation = _compare_eq_cls(left, right, highlighter, verbose)
elif issequence(left) and issequence(right):
explanation = _compare_eq_sequence(left, right, highlighter, verbose)
elif isset(left) and isset(right):
explanation = _compare_eq_set(left, right, highlighter, verbose)
elif isdict(left) and isdict(right):
explanation = _compare_eq_dict(left, right, highlighter, verbose)
if isiterable(left) and isiterable(right):
expl = _compare_eq_iterable(left, right, highlighter, verbose)
explanation.extend(expl)
return explanation
def _diff_text(
left: str, right: str, highlighter: _HighlightFunc, verbose: int = 0
) -> list[str]:
"""Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing
characters which are identical to keep the diff minimal.
"""
from difflib import ndiff
explanation: list[str] = []
if verbose < 1:
i = 0 # just in case left or right has zero length
for i in range(min(len(left), len(right))):
if left[i] != right[i]:
break
if i > 42:
i -= 10 # Provide some context
explanation = [
f"Skipping {i} identical leading characters in diff, use -v to show"
]
left = left[i:]
right = right[i:]
if len(left) == len(right):
for i in range(len(left)):
if left[-i] != right[-i]:
break
if i > 42:
i -= 10 # Provide some context
explanation += [
f"Skipping {i} identical trailing "
"characters in diff, use -v to show"
]
left = left[:-i]
right = right[:-i]
keepends = True
if left.isspace() or right.isspace():
left = repr(str(left))
right = repr(str(right))
explanation += ["Strings contain only whitespace, escaping them using repr()"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
highlighter(
"\n".join(
line.strip("\n")
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
),
lexer="diff",
).splitlines()
)
return explanation
def _compare_eq_iterable(
left: Iterable[Any],
right: Iterable[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"]
# dynamic import to speedup pytest
import difflib
left_formatting = PrettyPrinter().pformat(left).splitlines()
right_formatting = PrettyPrinter().pformat(right).splitlines()
explanation = ["", "Full diff:"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
highlighter(
"\n".join(
line.rstrip()
for line in difflib.ndiff(right_formatting, left_formatting)
),
lexer="diff",
).splitlines()
)
return explanation
def _compare_eq_sequence(
left: Sequence[Any],
right: Sequence[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: list[str] = []
len_left = len(left)
len_right = len(right)
for i in range(min(len_left, len_right)):
if left[i] != right[i]:
if comparing_bytes:
# when comparing bytes, we want to see their ascii representation
# instead of their numeric values (#5260)
# using a slice gives us the ascii representation:
# >>> s = b'foo'
# >>> s[0]
# 102
# >>> s[0:1]
# b'f'
left_value = left[i : i + 1]
right_value = right[i : i + 1]
else:
left_value = left[i]
right_value = right[i]
explanation.append(
f"At index {i} diff:"
f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}"
)
break
if comparing_bytes:
# when comparing bytes, it doesn't help to show the "sides contain one or more
# items" longer explanation, so skip it
return explanation
len_diff = len_left - len_right
if len_diff:
if len_diff > 0:
dir_with_more = "Left"
extra = saferepr(left[len_right])
else:
len_diff = 0 - len_diff
dir_with_more = "Right"
extra = saferepr(right[len_left])
if len_diff == 1:
explanation += [
f"{dir_with_more} contains one more item: {highlighter(extra)}"
]
else:
explanation += [
f"{dir_with_more} contains {len_diff} more items, first extra item: {highlighter(extra)}"
]
return explanation
def _compare_eq_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
return explanation
def _compare_gt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation
def _compare_lt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation
def _compare_gte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter)
def _compare_lte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("left", left, right, highlighter)
def _set_one_sided_diff(
posn: str,
set1: AbstractSet[Any],
set2: AbstractSet[Any],
highlighter: _HighlightFunc,
) -> list[str]:
explanation = []
diff = set1 - set2
if diff:
explanation.append(f"Extra items in the {posn} set:")
for item in diff:
explanation.append(highlighter(saferepr(item)))
return explanation
def _compare_eq_dict(
left: Mapping[Any, Any],
right: Mapping[Any, Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation: list[str] = []
set_left = set(left)
set_right = set(right)
common = set_left.intersection(set_right)
same = {k: left[k] for k in common if left[k] == right[k]}
if same and verbose < 2:
explanation += [f"Omitting {len(same)} identical items, use -vv to show"]
elif same:
explanation += ["Common items:"]
explanation += highlighter(pprint.pformat(same)).splitlines()
diff = {k for k in common if left[k] != right[k]}
if diff:
explanation += ["Differing items:"]
for k in diff:
explanation += [
highlighter(saferepr({k: left[k]}))
+ " != "
+ highlighter(saferepr({k: right[k]}))
]
extra_left = set_left - set_right
len_extra_left = len(extra_left)
if len_extra_left:
explanation.append(
f"Left contains {len_extra_left} more item{'' if len_extra_left == 1 else 's'}:"
)
explanation.extend(
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines()
)
extra_right = set_right - set_left
len_extra_right = len(extra_right)
if len_extra_right:
explanation.append(
f"Right contains {len_extra_right} more item{'' if len_extra_right == 1 else 's'}:"
)
explanation.extend(
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines()
)
return explanation
def _compare_eq_cls(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
) -> list[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
import dataclasses
all_fields = dataclasses.fields(left)
fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
elif isnamedtuple(left):
fields_to_check = left._fields
else:
assert False
indent = " "
same = []
diff = []
for field in fields_to_check:
if getattr(left, field) == getattr(right, field):
same.append(field)
else:
diff.append(field)
explanation = []
if same or diff:
explanation += [""]
if same and verbose < 2:
explanation.append(f"Omitting {len(same)} identical items, use -vv to show")
elif same:
explanation += ["Matching attributes:"]
explanation += highlighter(pprint.pformat(same)).splitlines()
if diff:
explanation += ["Differing attributes:"]
explanation += highlighter(pprint.pformat(diff)).splitlines()
for field in diff:
field_left = getattr(left, field)
field_right = getattr(right, field)
explanation += [
"",
f"Drill down into differing attribute {field}:",
f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}",
]
explanation += [
indent + line
for line in _compare_eq_any(
field_left, field_right, highlighter, verbose
)
]
return explanation
def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
index = text.find(term)
head = text[:index]
tail = text[index + len(term) :]
correct_text = head + tail
diff = _diff_text(text, correct_text, dummy_highlighter, verbose)
newdiff = [f"{saferepr(term, maxsize=42)} is contained here:"]
for line in diff:
if line.startswith("Skipping"):
continue
if line.startswith("- "):
continue
if line.startswith("+ "):
newdiff.append(" " + line[2:])
else:
newdiff.append(line)
return newdiff

View File

@@ -0,0 +1,646 @@
# mypy: allow-untyped-defs
"""Implementation of the cache provider."""
# This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version.
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Iterable
import dataclasses
import errno
import json
import os
from pathlib import Path
import tempfile
from typing import final
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.nodes import Directory
from _pytest.nodes import File
from _pytest.reports import TestReport
README_CONTENT = """\
# pytest cache directory #
This directory contains data from the pytest's cache plugin,
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
**Do not** commit this to version control.
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
"""
CACHEDIR_TAG_CONTENT = b"""\
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by pytest.
# For information about cache directory tags, see:
# https://bford.info/cachedir/spec.html
"""
@final
@dataclasses.dataclass
class Cache:
"""Instance of the `cache` fixture."""
_cachedir: Path = dataclasses.field(repr=False)
_config: Config = dataclasses.field(repr=False)
# Sub-directory under cache-dir for directories created by `mkdir()`.
_CACHE_PREFIX_DIRS = "d"
# Sub-directory under cache-dir for values created by `set()`.
_CACHE_PREFIX_VALUES = "v"
def __init__(
self, cachedir: Path, config: Config, *, _ispytest: bool = False
) -> None:
check_ispytest(_ispytest)
self._cachedir = cachedir
self._config = config
@classmethod
def for_config(cls, config: Config, *, _ispytest: bool = False) -> Cache:
"""Create the Cache instance for a Config.
:meta private:
"""
check_ispytest(_ispytest)
cachedir = cls.cache_dir_from_config(config, _ispytest=True)
if config.getoption("cacheclear") and cachedir.is_dir():
cls.clear_cache(cachedir, _ispytest=True)
return cls(cachedir, config, _ispytest=True)
@classmethod
def clear_cache(cls, cachedir: Path, _ispytest: bool = False) -> None:
"""Clear the sub-directories used to hold cached directories and values.
:meta private:
"""
check_ispytest(_ispytest)
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
d = cachedir / prefix
if d.is_dir():
rm_rf(d)
@staticmethod
def cache_dir_from_config(config: Config, *, _ispytest: bool = False) -> Path:
"""Get the path to the cache directory for a Config.
:meta private:
"""
check_ispytest(_ispytest)
return resolve_from_str(config.getini("cache_dir"), config.rootpath)
def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
"""Issue a cache warning.
:meta private:
"""
check_ispytest(_ispytest)
import warnings
from _pytest.warning_types import PytestCacheWarning
warnings.warn(
PytestCacheWarning(fmt.format(**args) if args else fmt),
self._config.hook,
stacklevel=3,
)
def _mkdir(self, path: Path) -> None:
self._ensure_cache_dir_and_supporting_files()
path.mkdir(exist_ok=True, parents=True)
def mkdir(self, name: str) -> Path:
"""Return a directory path object with the given name.
If the directory does not yet exist, it will be created. You can use
it to manage files to e.g. store/retrieve database dumps across test
sessions.
.. versionadded:: 7.0
:param name:
Must be a string not containing a ``/`` separator.
Make sure the name contains your plugin or application
identifiers to prevent clashes with other cache users.
"""
path = Path(name)
if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
self._mkdir(res)
return res
def _getvaluepath(self, key: str) -> Path:
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
def get(self, key: str, default):
"""Return the cached value for the given key.
If no value was yet cached or the value cannot be read, the specified
default is returned.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param default:
The value to return in case of a cache-miss or invalid cache value.
"""
path = self._getvaluepath(key)
try:
with path.open("r", encoding="UTF-8") as f:
return json.load(f)
except (ValueError, OSError):
return default
def set(self, key: str, value: object) -> None:
"""Save value for the given key.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param value:
Must be of any combination of basic python types,
including nested types like lists of dictionaries.
"""
path = self._getvaluepath(key)
try:
self._mkdir(path.parent)
except OSError as exc:
self.warn(
f"could not create cache path {path}: {exc}",
_ispytest=True,
)
return
data = json.dumps(value, ensure_ascii=False, indent=2)
try:
f = path.open("w", encoding="UTF-8")
except OSError as exc:
self.warn(
f"cache could not write path {path}: {exc}",
_ispytest=True,
)
else:
with f:
f.write(data)
def _ensure_cache_dir_and_supporting_files(self) -> None:
"""Create the cache dir and its supporting files."""
if self._cachedir.is_dir():
return
self._cachedir.parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory(
prefix="pytest-cache-files-",
dir=self._cachedir.parent,
) as newpath:
path = Path(newpath)
# Reset permissions to the default, see #12308.
# Note: there's no way to get the current umask atomically, eek.
umask = os.umask(0o022)
os.umask(umask)
path.chmod(0o777 - umask)
with open(path.joinpath("README.md"), "x", encoding="UTF-8") as f:
f.write(README_CONTENT)
with open(path.joinpath(".gitignore"), "x", encoding="UTF-8") as f:
f.write("# Created by pytest automatically.\n*\n")
with open(path.joinpath("CACHEDIR.TAG"), "xb") as f:
f.write(CACHEDIR_TAG_CONTENT)
try:
path.rename(self._cachedir)
except OSError as e:
# If 2 concurrent pytests both race to the rename, the loser
# gets "Directory not empty" from the rename. In this case,
# everything is handled so just continue (while letting the
# temporary directory be cleaned up).
# On Windows, the error is a FileExistsError which translates to EEXIST.
if e.errno not in (errno.ENOTEMPTY, errno.EEXIST):
raise
else:
# Create a directory in place of the one we just moved so that
# `TemporaryDirectory`'s cleanup doesn't complain.
#
# TODO: pass ignore_cleanup_errors=True when we no longer support python < 3.10.
# See https://github.com/python/cpython/issues/74168. Note that passing
# delete=False would do the wrong thing in case of errors and isn't supported
# until python 3.12.
path.mkdir()
class LFPluginCollWrapper:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
self._collected_at_least_one_failure = False
@hookimpl(wrapper=True)
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Generator[None, CollectReport, CollectReport]:
res = yield
if isinstance(collector, Session | Directory):
# Sort any lf-paths to the beginning.
lf_paths = self.lfplugin._last_failed_paths
# Use stable sort to prioritize last failed.
def sort_key(node: nodes.Item | nodes.Collector) -> bool:
return node.path in lf_paths
res.result = sorted(
res.result,
key=sort_key,
reverse=True,
)
elif isinstance(collector, File):
if collector.path in self.lfplugin._last_failed_paths:
result = res.result
lastfailed = self.lfplugin.lastfailed
# Only filter with known failures.
if not self._collected_at_least_one_failure:
if not any(x.nodeid in lastfailed for x in result):
return res
self.lfplugin.config.pluginmanager.register(
LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
)
self._collected_at_least_one_failure = True
session = collector.session
result[:] = [
x
for x in result
if x.nodeid in lastfailed
# Include any passed arguments (not trivial to filter).
or session.isinitpath(x.path)
# Keep all sub-collectors.
or isinstance(x, nodes.Collector)
]
return res
class LFPluginCollSkipfiles:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
@hookimpl
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> CollectReport | None:
if isinstance(collector, File):
if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
return CollectReport(
collector.nodeid, "passed", longrepr=None, result=[]
)
return None
class LFPlugin:
"""Plugin which implements the --lf (run last-failing) option."""
def __init__(self, config: Config) -> None:
self.config = config
active_keys = "lf", "failedfirst"
self.active = any(config.getoption(key) for key in active_keys)
assert config.cache
self.lastfailed: dict[str, bool] = config.cache.get("cache/lastfailed", {})
self._previously_failed_count: int | None = None
self._report_status: str | None = None
self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"):
self._last_failed_paths = self.get_last_failed_paths()
config.pluginmanager.register(
LFPluginCollWrapper(self), "lfplugin-collwrapper"
)
def get_last_failed_paths(self) -> set[Path]:
"""Return a set with all Paths of the previously failed nodeids and
their parents."""
rootpath = self.config.rootpath
result = set()
for nodeid in self.lastfailed:
path = rootpath / nodeid.split("::")[0]
result.add(path)
result.update(path.parents)
return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> str | None:
if self.active and self.config.get_verbosity() >= 0:
return f"run-last-failure: {self._report_status}"
return None
def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None)
elif report.failed:
self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
if passed:
if report.nodeid in self.lastfailed:
self.lastfailed.pop(report.nodeid)
self.lastfailed.update((item.nodeid, True) for item in report.result)
else:
self.lastfailed[report.nodeid] = True
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
self, config: Config, items: list[nodes.Item]
) -> Generator[None]:
res = yield
if not self.active:
return res
if self.lastfailed:
previously_failed = []
previously_passed = []
for item in items:
if item.nodeid in self.lastfailed:
previously_failed.append(item)
else:
previously_passed.append(item)
self._previously_failed_count = len(previously_failed)
if not previously_failed:
# Running a subset of all tests with recorded failures
# only outside of it.
self._report_status = (
f"{len(self.lastfailed)} known failures not in selected tests"
)
else:
if self.config.getoption("lf"):
items[:] = previously_failed
config.hook.pytest_deselected(items=previously_passed)
else: # --failedfirst
items[:] = previously_failed + previously_passed
noun = "failure" if self._previously_failed_count == 1 else "failures"
suffix = " first" if self.config.getoption("failedfirst") else ""
self._report_status = (
f"rerun previous {self._previously_failed_count} {noun}{suffix}"
)
if self._skipped_files > 0:
files_noun = "file" if self._skipped_files == 1 else "files"
self._report_status += f" (skipped {self._skipped_files} {files_noun})"
else:
self._report_status = "no previously failed tests, "
if self.config.getoption("last_failed_no_failures") == "none":
self._report_status += "deselecting all items."
config.hook.pytest_deselected(items=items[:])
items[:] = []
else:
self._report_status += "not deselecting items."
return res
def pytest_sessionfinish(self, session: Session) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {})
if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed)
class NFPlugin:
"""Plugin which implements the --nf (run new-first) option."""
def __init__(self, config: Config) -> None:
self.config = config
self.active = config.option.newfirst
assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(self, items: list[nodes.Item]) -> Generator[None]:
res = yield
if self.active:
new_items: dict[str, nodes.Item] = {}
other_items: dict[str, nodes.Item] = {}
for item in items:
if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item
else:
other_items[item.nodeid] = item
items[:] = self._get_increasing_order(
new_items.values()
) + self._get_increasing_order(other_items.values())
self.cached_nodeids.update(new_items)
else:
self.cached_nodeids.update(item.nodeid for item in items)
return res
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> list[nodes.Item]:
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True)
def pytest_sessionfinish(self) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
if config.getoption("collectonly"):
return
assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
def pytest_addoption(parser: Parser) -> None:
"""Add command-line options for cache functionality.
:param parser: Parser object to add command-line options to.
"""
group = parser.getgroup("general")
group.addoption(
"--lf",
"--last-failed",
action="store_true",
dest="lf",
help="Rerun only the tests that failed at the last run (or all if none failed)",
)
group.addoption(
"--ff",
"--failed-first",
action="store_true",
dest="failedfirst",
help="Run all tests, but run the last failures first. "
"This may re-order tests and thus lead to "
"repeated fixture setup/teardown.",
)
group.addoption(
"--nf",
"--new-first",
action="store_true",
dest="newfirst",
help="Run tests from new files first, then the rest of the tests "
"sorted by file mtime",
)
group.addoption(
"--cache-show",
action="append",
nargs="?",
dest="cacheshow",
help=(
"Show cache contents, don't perform collection or tests. "
"Optional argument: glob (default: '*')."
),
)
group.addoption(
"--cache-clear",
action="store_true",
dest="cacheclear",
help="Remove all cache contents at start of test run",
)
cache_dir_default = ".pytest_cache"
if "TOX_ENV_DIR" in os.environ:
cache_dir_default = os.path.join(os.environ["TOX_ENV_DIR"], cache_dir_default)
parser.addini("cache_dir", default=cache_dir_default, help="Cache directory path")
group.addoption(
"--lfnf",
"--last-failed-no-failures",
action="store",
dest="last_failed_no_failures",
choices=("all", "none"),
default="all",
help="With ``--lf``, determines whether to execute tests when there "
"are no previously (known) failures or when no "
"cached ``lastfailed`` data was found. "
"``all`` (the default) runs the full test suite again. "
"``none`` just emits a message about no known failures and exits successfully.",
)
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.cacheshow and not config.option.help:
from _pytest.main import wrap_session
return wrap_session(config, cacheshow)
return None
@hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None:
"""Configure cache system and register related plugins.
Creates the Cache instance and registers the last-failed (LFPlugin)
and new-first (NFPlugin) plugins with the plugin manager.
:param config: pytest configuration object.
"""
config.cache = Cache.for_config(config, _ispytest=True)
config.pluginmanager.register(LFPlugin(config), "lfplugin")
config.pluginmanager.register(NFPlugin(config), "nfplugin")
@fixture
def cache(request: FixtureRequest) -> Cache:
"""Return a cache object that can persist state between testing sessions.
cache.get(key, default)
cache.set(key, value)
Keys must be ``/`` separated strings, where the first part is usually the
name of your plugin or application to avoid clashes with other cache users.
Values can be any object handled by the json stdlib module.
"""
assert request.config.cache is not None
return request.config.cache
def pytest_report_header(config: Config) -> str | None:
"""Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
assert config.cache is not None
cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths
# starting with .., ../.. if sensible
try:
displaypath = cachedir.relative_to(config.rootpath)
except ValueError:
displaypath = cachedir
return f"cachedir: {displaypath}"
return None
def cacheshow(config: Config, session: Session) -> int:
"""Display cache contents when --cache-show is used.
Shows cached values and directories matching the specified glob pattern
(default: '*'). Displays cache location, cached test results, and
any cached directories created by plugins.
:param config: pytest configuration object.
:param session: pytest session object.
:returns: Exit code (0 for success).
"""
from pprint import pformat
assert config.cache is not None
tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir():
tw.line("cache is empty")
return 0
glob = config.option.cacheshow[0]
if glob is None:
glob = "*"
dummy = object()
basedir = config.cache._cachedir
vdir = basedir / Cache._CACHE_PREFIX_VALUES
tw.sep("-", f"cache values for {glob!r}")
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
key = str(valpath.relative_to(vdir))
val = config.cache.get(key, dummy)
if val is dummy:
tw.line(f"{key} contains unreadable content, will be ignored")
else:
tw.line(f"{key} contains:")
for line in pformat(val).splitlines():
tw.line(" " + line)
ddir = basedir / Cache._CACHE_PREFIX_DIRS
if ddir.is_dir():
contents = sorted(ddir.rglob(glob))
tw.sep("-", f"cache directories for {glob!r}")
for p in contents:
# if p.is_dir():
# print("%s/" % p.relative_to(basedir))
if p.is_file():
key = str(p.relative_to(basedir))
tw.line(f"{key} is a file of length {p.stat().st_size}")
return 0

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,314 @@
# mypy: allow-untyped-defs
"""Python version compatibility code and random general utilities."""
from __future__ import annotations
from collections.abc import Callable
import enum
import functools
import inspect
from inspect import Parameter
from inspect import Signature
import os
from pathlib import Path
import sys
from typing import Any
from typing import Final
from typing import NoReturn
import py
if sys.version_info >= (3, 14):
from annotationlib import Format
#: constant to prepare valuing pylib path replacements/lazy proxies later on
# intended for removal in pytest 8.0 or 9.0
# fmt: off
# intentional space to create a fake difference for the verification
LEGACY_PATH = py.path. local
# fmt: on
def legacy_path(path: str | os.PathLike[str]) -> LEGACY_PATH:
"""Internal wrapper to prepare lazy proxies for legacy_path instances"""
return LEGACY_PATH(path)
# fmt: off
# Singleton type for NOTSET, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class NotSetType(enum.Enum):
token = 0
NOTSET: Final = NotSetType.token
# fmt: on
def iscoroutinefunction(func: object) -> bool:
"""Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with
@asyncio.coroutine.
Note: copied and modified from Python 3.5's builtin coroutines.py to avoid
importing asyncio directly, which in turns also initializes the "logging"
module as a side-effect (see issue #8).
"""
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def is_async_function(func: object) -> bool:
"""Return True if the given function seems to be an async function or
an async generator."""
return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
def signature(obj: Callable[..., Any]) -> Signature:
"""Return signature without evaluating annotations."""
if sys.version_info >= (3, 14):
return inspect.signature(obj, annotation_format=Format.STRING)
return inspect.signature(obj)
def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
function = get_real_func(function)
fn = Path(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
if curdir is not None:
try:
relfn = fn.relative_to(curdir)
except ValueError:
pass
else:
return f"{relfn}:{lineno + 1}"
return f"{fn}:{lineno + 1}"
def num_mock_patch_args(function) -> int:
"""Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None)
if not patchings:
return 0
mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object())
ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object())
return len(
[
p
for p in patchings
if not p.attribute_name
and (p.new is mock_sentinel or p.new is ut_mock_sentinel)
]
)
def getfuncargnames(
function: Callable[..., object],
*,
name: str = "",
cls: type | None = None,
) -> tuple[str, ...]:
"""Return the names of a function's mandatory arguments.
Should return the names of all function arguments that:
* Aren't bound to an instance or type as in instance or class methods.
* Don't have default values.
* Aren't bound with functools.partial.
* Aren't replaced with mocks.
The cls arguments indicate that the function should be treated as a bound
method even though it's not unless the function is a static method.
The name parameter should be the original name in which the function was collected.
"""
# TODO(RonnyPfannschmidt): This function should be refactored when we
# revisit fixtures. The fixture mechanism should ask the node for
# the fixture names, and not try to obtain directly from the
# function object well after collection has occurred.
# The parameters attribute of a Signature object contains an
# ordered mapping of parameter names to Parameter instances. This
# creates a tuple of the names of the parameters that don't have
# defaults.
try:
parameters = signature(function).parameters.values()
except (ValueError, TypeError) as e:
from _pytest.outcomes import fail
fail(
f"Could not determine arguments of {function!r}: {e}",
pytrace=False,
)
arg_names = tuple(
p.name
for p in parameters
if (
p.kind is Parameter.POSITIONAL_OR_KEYWORD
or p.kind is Parameter.KEYWORD_ONLY
)
and p.default is Parameter.empty
)
if not name:
name = function.__name__
# If this function should be treated as a bound method even though
# it's passed as an unbound method or function, and its first parameter
# wasn't defined as positional only, remove the first parameter name.
if not any(p.kind is Parameter.POSITIONAL_ONLY for p in parameters) and (
# Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO.
cls
and not isinstance(
inspect.getattr_static(cls, name, default=None), staticmethod
)
):
arg_names = arg_names[1:]
# Remove any names that will be replaced with mocks.
if hasattr(function, "__wrapped__"):
arg_names = arg_names[num_mock_patch_args(function) :]
return arg_names
def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of
# getfuncargnames, to get the arguments which were excluded from its result
# because they had default values.
return tuple(
p.name
for p in signature(function).parameters.values()
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
and p.default is not Parameter.empty
)
_non_printable_ascii_translate_table = {
i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127)
}
_non_printable_ascii_translate_table.update(
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"}
)
def ascii_escaped(val: bytes | str) -> str:
r"""If val is pure ASCII, return it as an str, otherwise, escape
bytes objects into a sequence of escaped bytes:
b'\xc3\xb4\xc5\xd6' -> r'\xc3\xb4\xc5\xd6'
and escapes strings into a sequence of escaped unicode ids, e.g.:
r'4\nV\U00043efa\x0eMXWB\x1e\u3028\u15fd\xcd\U0007d944'
Note:
The obvious "v.decode('unicode-escape')" will return
valid UTF-8 unicode if it finds them in bytes, but we
want to return escaped bytes for any byte, even if they match
a UTF-8 string.
"""
if isinstance(val, bytes):
ret = val.decode("ascii", "backslashreplace")
else:
ret = val.encode("unicode_escape").decode("ascii")
return ret.translate(_non_printable_ascii_translate_table)
def get_real_func(obj):
"""Get the real function object of the (possibly) wrapped object by
:func:`functools.wraps`, or :func:`functools.partial`."""
obj = inspect.unwrap(obj)
if isinstance(obj, functools.partial):
obj = obj.func
return obj
def getimfunc(func):
try:
return func.__func__
except AttributeError:
return func
def safe_getattr(object: Any, name: str, default: Any) -> Any:
"""Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects.
See issue #214.
It catches OutcomeException because of #2490 (issue #580), new outcomes
are derived from BaseException instead of Exception (for more details
check #2707).
"""
from _pytest.outcomes import TEST_OUTCOME
try:
return getattr(object, name, default)
except TEST_OUTCOME:
return default
def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3."""
try:
return inspect.isclass(obj)
except Exception:
return False
def get_user_id() -> int | None:
"""Return the current process's real user id or None if it could not be
determined.
:return: The user id or None if it could not be determined.
"""
# mypy follows the version and platform checking expectation of PEP 484:
# https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=platform#python-version-and-system-platform-checks
# Containment checks are too complex for mypy v1.5.0 and cause failure.
if sys.platform == "win32" or sys.platform == "emscripten":
# win32 does not have a getuid() function.
# Emscripten has a return 0 stub.
return None
else:
# On other platforms, a return value of -1 is assumed to indicate that
# the current process's real user id could not be determined.
ERROR = -1
uid = os.getuid()
return uid if uid != ERROR else None
if sys.version_info >= (3, 11):
from typing import assert_never
else:
def assert_never(value: NoReturn) -> NoReturn:
assert False, f"Unhandled value: {value} ({type(value).__name__})"
class CallableBool:
"""
A bool-like object that can also be called, returning its true/false value.
Used for backwards compatibility in cases where something was supposed to be a method
but was implemented as a simple attribute by mistake (see `TerminalReporter.isatty`).
Do not use in new code.
"""
def __init__(self, value: bool) -> None:
self._value = value
def __bool__(self) -> bool:
return self._value
def __call__(self) -> bool:
return self._value
def running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
# Only enable CI mode if one of these env variables is defined and non-empty.
# Note: review `regendoc` tox env in case this list is changed.
env_vars = ["CI", "BUILD_NUMBER"]
return any(os.environ.get(var) for var in env_vars)

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More