Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,5 @@
"""
The `pip_audit` APIs.
"""
__version__ = "2.9.0"

View File

@@ -0,0 +1,8 @@
"""
The `python -m pip_audit` entrypoint.
"""
if __name__ == "__main__": # pragma: no cover
from pip_audit._cli import audit
audit()

View File

@@ -0,0 +1,96 @@
"""
Core auditing APIs.
"""
from __future__ import annotations
import logging
from collections.abc import Iterator
from dataclasses import dataclass
from pip_audit._dependency_source import DependencySource
from pip_audit._service import Dependency, VulnerabilityResult, VulnerabilityService
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class AuditOptions:
"""
    Settings that control the behavior of an `Auditor` instance.
"""
dry_run: bool = False
class Auditor:
"""
The core class of the `pip-audit` API.
For a given dependency source and vulnerability service, supply a mapping of dependencies to
known vulnerabilities.
"""
def __init__(
self,
service: VulnerabilityService,
options: AuditOptions = AuditOptions(),
):
"""
Create a new auditor. Auditors start with no dependencies to audit;
each `audit` step is fed a `DependencySource`.
The behavior of the auditor can be optionally tweaked with the `options`
parameter.
"""
self._service = service
self._options = options
def audit(
self, source: DependencySource
) -> Iterator[tuple[Dependency, list[VulnerabilityResult]]]:
"""
Perform the auditing step, collecting dependencies from `source`.
Individual vulnerability results are uniqued based on their `aliases` sets:
any two results for the same dependency that share an alias are collapsed
into a single result with a union of all aliases.
`PYSEC`-identified results are given priority over other results.
"""
specs = source.collect()
if self._options.dry_run:
# Drain the iterator in dry-run mode.
logger.info(f"Dry run: would have audited {len(list(specs))} packages")
yield from ()
else:
for dep, vulns in self._service.query_all(specs):
unique_vulns: list[VulnerabilityResult] = []
seen_aliases: set[str] = set()
                # First pass: add all PYSEC vulnerabilities and track their
                # alias sets.
for v in vulns:
if not v.id.startswith("PYSEC"):
continue
seen_aliases.update(v.aliases | {v.id})
unique_vulns.append(v)
# Second pass: add any non-PYSEC vulnerabilities.
for v in vulns:
# If we've already seen this vulnerability by another name,
# don't add it. Instead, find the previous result and update
# its alias set.
if seen_aliases.intersection(v.aliases | {v.id}):
idx, previous = next(
(i, p) for (i, p) in enumerate(unique_vulns) if p.alias_of(v)
)
unique_vulns[idx] = previous.merge_aliases(v)
continue
seen_aliases.update(v.aliases | {v.id})
unique_vulns.append(v)
yield (dep, unique_vulns)
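A minimal usage sketch of the API above (not part of the commit; `PyPIService` and `PipSource` come from elsewhere in this changeset, and the 15-second timeout is an illustrative value):

# Hedged sketch: audit the local environment against PyPI and print findings.
from pip_audit._audit import AuditOptions, Auditor
from pip_audit._dependency_source import PipSource
from pip_audit._service import PyPIService

auditor = Auditor(PyPIService(None, 15), options=AuditOptions(dry_run=False))
for dep, vulns in auditor.audit(PipSource()):
    for vuln in vulns:
        print(f"{dep.name}: {vuln.id}")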

View File

@@ -0,0 +1,178 @@
"""
Caching middleware for `pip-audit`.
"""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import sys
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any
import pip_api
import requests
from cachecontrol import CacheControl
from cachecontrol.caches import FileCache
from packaging.version import Version
from platformdirs import user_cache_path
from pip_audit._service.interface import ServiceError
logger = logging.getLogger(__name__)
# The `cache dir` subcommand was added to `pip` in 20.1, so we should check the version
# before trying to use it to discover the `pip` HTTP cache.
_MINIMUM_PIP_VERSION = Version("20.1")
_PIP_VERSION = Version(str(pip_api.PIP_VERSION))
_PIP_AUDIT_LEGACY_INTERNAL_CACHE = Path.home() / ".pip-audit-cache"
def _get_pip_cache() -> Path:
# Unless the cache directory is specifically set by the `--cache-dir` option, we try to share
# the `pip` HTTP cache
cmd = [sys.executable, "-m", "pip", "cache", "dir"]
try:
process = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
except subprocess.CalledProcessError as cpe: # pragma: no cover
# NOTE: This should only happen if pip's cache has been explicitly disabled,
# which we check for in the caller (via `PIP_NO_CACHE_DIR`).
raise ServiceError(f"Failed to query the `pip` HTTP cache directory: {cmd}") from cpe
cache_dir = process.stdout.decode("utf-8").strip("\n")
http_cache_dir = Path(cache_dir) / "http"
return http_cache_dir
def _get_cache_dir(custom_cache_dir: Path | None, *, use_pip: bool = True) -> Path:
"""
Returns a directory path suitable for HTTP caching.
The directory is **not** guaranteed to exist.
`use_pip` tells the function to prefer `pip`'s pre-existing cache,
**unless** `PIP_NO_CACHE_DIR` is present in the environment.
"""
# If the user has explicitly requested a directory, pass it through unscathed.
if custom_cache_dir is not None:
return custom_cache_dir
# Retrieve pip-audit's default internal cache using `platformdirs`.
pip_audit_cache_dir = user_cache_path("pip-audit", appauthor=False, ensure_exists=True)
    # If the legacy cache directory still exists and isn't the one we just retrieved, delete it.
if (
_PIP_AUDIT_LEGACY_INTERNAL_CACHE.exists()
and pip_audit_cache_dir != _PIP_AUDIT_LEGACY_INTERNAL_CACHE
):
shutil.rmtree(_PIP_AUDIT_LEGACY_INTERNAL_CACHE)
# Respect pip's PIP_NO_CACHE_DIR environment setting.
if use_pip and not os.getenv("PIP_NO_CACHE_DIR"):
pip_cache_dir = _get_pip_cache() if _PIP_VERSION >= _MINIMUM_PIP_VERSION else None
if pip_cache_dir is not None:
return pip_cache_dir
else:
logger.warning(
f"pip {_PIP_VERSION} doesn't support the `cache dir` subcommand, "
f"using {pip_audit_cache_dir} instead"
)
return pip_audit_cache_dir
else:
return pip_audit_cache_dir
class _SafeFileCache(FileCache):
"""
A rough mirror of `pip`'s `SafeFileCache` that *should* be runtime-compatible
with `pip` (i.e., does not interfere with `pip` when it shares the same
caching directory as a running `pip` process).
"""
def __init__(self, directory: Path):
self._logged_warning = False
super().__init__(str(directory))
def get(self, key: str) -> Any | None:
try:
return super().get(key)
except Exception as e: # pragma: no cover
if not self._logged_warning:
logger.warning(
f"Failed to read from cache directory, performance may be degraded: {e}"
)
self._logged_warning = True
return None
def set(self, key: str, value: bytes, expires: Any | None = None) -> None:
try:
self._set_impl(key, value)
except Exception as e: # pragma: no cover
if not self._logged_warning:
logger.warning(
f"Failed to write to cache directory, performance may be degraded: {e}"
)
self._logged_warning = True
def _set_impl(self, key: str, value: bytes) -> None:
name: str = super()._fn(key)
# Make sure the directory exists
try:
os.makedirs(os.path.dirname(name), self.dirmode)
except OSError: # pragma: no cover
pass
# We don't want to use lock files since `pip` isn't going to recognise those. We should
# write to the cache in a similar way to how `pip` does it. We create a temporary file,
# then atomically replace the actual cache key's filename with it. This ensures
# that other concurrent `pip` or `pip-audit` instances don't read partial data.
with NamedTemporaryFile(delete=False, dir=os.path.dirname(name)) as io:
io.write(value)
# NOTE(ww): Similar to what `pip` does in `adjacent_tmp_file`.
io.flush()
os.fsync(io.fileno())
# NOTE(ww): Windows won't let us rename the temporary file until it's closed,
# which is why we call `os.replace()` here rather than in the `with` block above.
os.replace(io.name, name)
def delete(self, key: str) -> None: # pragma: no cover
try:
super().delete(key)
except Exception as e:
if not self._logged_warning:
logger.warning(
f"Failed to delete file from cache directory, performance may be degraded: {e}"
)
self._logged_warning = True
def caching_session(cache_dir: Path | None, *, use_pip: bool = False) -> requests.Session:
"""
Return a `requests` style session, with suitable caching middleware.
Uses the given `cache_dir` for the HTTP cache.
`use_pip` determines how the fallback cache directory is determined, if `cache_dir` is None.
When `use_pip` is `False`, `caching_session` will use a `pip-audit` internal cache directory.
When `use_pip` is `True`, `caching_session` will attempt to discover `pip`'s cache
directory, falling back on the internal `pip-audit` cache directory if the user's
version of `pip` is too old.
"""
# We limit the number of redirects to 5, since the services we connect to
# should really never redirect more than once or twice.
inner_session = requests.Session()
inner_session.max_redirects = 5
return CacheControl(
inner_session,
cache=_SafeFileCache(_get_cache_dir(cache_dir, use_pip=use_pip)),
)
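A short sketch of the session above in use (not part of the commit; the URL is illustrative, and `use_pip=True` assumes a modern `pip` whose HTTP cache can be discovered):

# Hedged sketch: repeated GETs through the caching session are served from
# the shared on-disk HTTP cache when the response allows caching.
from pip_audit._cache import caching_session

session = caching_session(None, use_pip=True)
response = session.get("https://pypi.org/pypi/sampleproject/json")
response.raise_for_status()
print(response.headers.get("ETag"))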

View File

@@ -0,0 +1,636 @@
"""
Command-line entrypoints for `pip-audit`.
"""
from __future__ import annotations
import argparse
import enum
import logging
import os
import sys
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from typing import IO, NoReturn, cast
from pip_audit import __version__
from pip_audit._audit import AuditOptions, Auditor
from pip_audit._dependency_source import (
DependencySource,
DependencySourceError,
PipSource,
PyProjectSource,
RequirementSource,
)
from pip_audit._dependency_source.pylock import PyLockSource
from pip_audit._fix import ResolvedFixVersion, SkippedFixVersion, resolve_fix_versions
from pip_audit._format import (
ColumnsFormat,
CycloneDxFormat,
JsonFormat,
MarkdownFormat,
VulnerabilityFormat,
)
from pip_audit._service import OsvService, PyPIService, VulnerabilityService
from pip_audit._service.interface import ConnectionError as VulnServiceConnectionError
from pip_audit._service.interface import ResolvedDependency, SkippedDependency
from pip_audit._state import AuditSpinner, AuditState
from pip_audit._util import assert_never
logging.basicConfig()
logger = logging.getLogger(__name__)
# NOTE: We configure the top package logger, rather than the root logger,
# to avoid overly verbose logging in third-party code by default.
package_logger = logging.getLogger("pip_audit")
package_logger.setLevel(os.environ.get("PIP_AUDIT_LOGLEVEL", "INFO").upper())
@contextmanager
def _output_io(name: Path) -> Iterator[IO[str]]: # pragma: no cover
"""
A context managing wrapper for pip-audit's `--output` flag. This allows us
to avoid `argparse.FileType`'s "eager" file creation, which is generally
the wrong/unexpected behavior when dealing with fallible processes.
"""
if str(name) in {"stdout", "-"}:
yield sys.stdout
else:
with name.open("w") as io:
yield io
@enum.unique
class OutputFormatChoice(str, enum.Enum):
"""
Output formats supported by the `pip-audit` CLI.
"""
Columns = "columns"
Json = "json"
CycloneDxJson = "cyclonedx-json"
CycloneDxXml = "cyclonedx-xml"
Markdown = "markdown"
def to_format(self, output_desc: bool, output_aliases: bool) -> VulnerabilityFormat:
if self is OutputFormatChoice.Columns:
return ColumnsFormat(output_desc, output_aliases)
elif self is OutputFormatChoice.Json:
return JsonFormat(output_desc, output_aliases)
elif self is OutputFormatChoice.CycloneDxJson:
return CycloneDxFormat(inner_format=CycloneDxFormat.InnerFormat.Json)
elif self is OutputFormatChoice.CycloneDxXml:
return CycloneDxFormat(inner_format=CycloneDxFormat.InnerFormat.Xml)
elif self is OutputFormatChoice.Markdown:
return MarkdownFormat(output_desc, output_aliases)
else:
assert_never(self) # pragma: no cover
def __str__(self) -> str:
return self.value
@enum.unique
class VulnerabilityServiceChoice(str, enum.Enum):
"""
Python vulnerability services supported by `pip-audit`.
"""
Osv = "osv"
Pypi = "pypi"
def to_service(self, timeout: int, cache_dir: Path | None) -> VulnerabilityService:
if self is VulnerabilityServiceChoice.Osv:
return OsvService(cache_dir, timeout)
elif self is VulnerabilityServiceChoice.Pypi:
return PyPIService(cache_dir, timeout)
else:
assert_never(self) # pragma: no cover
def __str__(self) -> str:
return self.value
@enum.unique
class VulnerabilityDescriptionChoice(str, enum.Enum):
"""
Whether or not vulnerability descriptions should be added to the `pip-audit` output.
"""
On = "on"
Off = "off"
Auto = "auto"
def to_bool(self, format_: OutputFormatChoice) -> bool:
if self is VulnerabilityDescriptionChoice.On:
return True
elif self is VulnerabilityDescriptionChoice.Off:
return False
elif self is VulnerabilityDescriptionChoice.Auto:
return bool(format_ is OutputFormatChoice.Json)
else:
assert_never(self) # pragma: no cover
def __str__(self) -> str:
return self.value
@enum.unique
class VulnerabilityAliasChoice(str, enum.Enum):
"""
Whether or not vulnerability aliases should be added to the `pip-audit` output.
"""
On = "on"
Off = "off"
Auto = "auto"
def to_bool(self, format_: OutputFormatChoice) -> bool:
if self is VulnerabilityAliasChoice.On:
return True
elif self is VulnerabilityAliasChoice.Off:
return False
elif self is VulnerabilityAliasChoice.Auto:
return bool(format_ is OutputFormatChoice.Json)
else:
assert_never(self) # pragma: no cover
def __str__(self) -> str:
return self.value
@enum.unique
class ProgressSpinnerChoice(str, enum.Enum):
"""
Whether or not `pip-audit` should display a progress spinner.
"""
On = "on"
Off = "off"
def __bool__(self) -> bool:
return self is ProgressSpinnerChoice.On
def __str__(self) -> str:
return self.value
def _enum_help(msg: str, e: type[enum.Enum]) -> str: # pragma: no cover
"""
Render a `--help`-style string for the given enumeration.
"""
return f"{msg} (choices: {', '.join(str(v) for v in e)})"
def _fatal(msg: str) -> NoReturn: # pragma: no cover
"""
Log a fatal error to the standard error stream and exit.
"""
# NOTE: We buffer the logger when the progress spinner is active,
# ensuring that the fatal message is formatted on its own line.
logger.error(msg)
sys.exit(1)
def _parser() -> argparse.ArgumentParser: # pragma: no cover
parser = argparse.ArgumentParser(
prog="pip-audit",
description="audit the Python environment for dependencies with known vulnerabilities",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
dep_source_args = parser.add_mutually_exclusive_group()
parser.add_argument("-V", "--version", action="version", version=f"%(prog)s {__version__}")
parser.add_argument(
"-l",
"--local",
action="store_true",
help="show only results for dependencies in the local environment",
)
dep_source_args.add_argument(
"-r",
"--requirement",
type=Path,
metavar="REQUIREMENT",
action="append",
dest="requirements",
help="audit the given requirements file; this option can be used multiple times",
)
dep_source_args.add_argument(
"project_path",
type=Path,
nargs="?",
help="audit a local Python project at the given path",
)
parser.add_argument(
"--locked",
action="store_true",
help="audit lock files from the local Python project. This "
"flag only applies to auditing from project paths",
)
parser.add_argument(
"-f",
"--format",
type=OutputFormatChoice,
choices=OutputFormatChoice,
default=os.environ.get("PIP_AUDIT_FORMAT", OutputFormatChoice.Columns),
metavar="FORMAT",
help=_enum_help("the format to emit audit results in", OutputFormatChoice),
)
parser.add_argument(
"-s",
"--vulnerability-service",
type=VulnerabilityServiceChoice,
choices=VulnerabilityServiceChoice,
default=os.environ.get("PIP_AUDIT_VULNERABILITY_SERVICE", VulnerabilityServiceChoice.Pypi),
metavar="SERVICE",
help=_enum_help(
"the vulnerability service to audit dependencies against",
VulnerabilityServiceChoice,
),
)
parser.add_argument(
"-d",
"--dry-run",
action="store_true",
help="without `--fix`: collect all dependencies but do not perform the auditing step; "
"with `--fix`: perform the auditing step but do not perform any fixes",
)
parser.add_argument(
"-S",
"--strict",
action="store_true",
help="fail the entire audit if dependency collection fails on any dependency",
)
parser.add_argument(
"--desc",
type=VulnerabilityDescriptionChoice,
choices=VulnerabilityDescriptionChoice,
nargs="?",
const=VulnerabilityDescriptionChoice.On,
default=os.environ.get("PIP_AUDIT_DESC", VulnerabilityDescriptionChoice.Auto),
help="include a description for each vulnerability; "
"`auto` defaults to `on` for the `json` format. This flag has no "
"effect on the `cyclonedx-json` or `cyclonedx-xml` formats.",
)
parser.add_argument(
"--aliases",
type=VulnerabilityAliasChoice,
choices=VulnerabilityAliasChoice,
nargs="?",
const=VulnerabilityAliasChoice.On,
default=VulnerabilityAliasChoice.Auto,
help="includes alias IDs for each vulnerability; "
"`auto` defaults to `on` for the `json` format. This flag has no "
"effect on the `cyclonedx-json` or `cyclonedx-xml` formats.",
)
parser.add_argument(
"--cache-dir",
type=Path,
help="the directory to use as an HTTP cache for PyPI; uses the `pip` HTTP cache by default",
)
parser.add_argument(
"--progress-spinner",
type=ProgressSpinnerChoice,
choices=ProgressSpinnerChoice,
default=os.environ.get("PIP_AUDIT_PROGRESS_SPINNER", ProgressSpinnerChoice.On),
help="display a progress spinner",
)
parser.add_argument(
"--timeout",
type=int,
default=15,
help="set the socket timeout", # Match the `pip` default
)
dep_source_args.add_argument(
"--path",
type=Path,
metavar="PATH",
action="append",
dest="paths",
default=[],
help="restrict to the specified installation path for auditing packages; "
"this option can be used multiple times",
)
parser.add_argument(
"-v",
"--verbose",
action="count",
default=0,
help="run with additional debug logging; supply multiple times to increase verbosity",
)
parser.add_argument(
"--fix",
action="store_true",
help="automatically upgrade dependencies with known vulnerabilities",
)
parser.add_argument(
"--require-hashes",
action="store_true",
help="require a hash to check each requirement against, for repeatable audits; this option "
"is implied when any package in a requirements file has a `--hash` option.",
)
parser.add_argument(
"--index-url",
type=str,
help="base URL of the Python Package Index; this should point to a repository compliant "
"with PEP 503 (the simple repository API); this will be resolved by pip if not specified",
)
parser.add_argument(
"--extra-index-url",
type=str,
metavar="URL",
action="append",
dest="extra_index_urls",
default=[],
help="extra URLs of package indexes to use in addition to `--index-url`; should follow the "
"same rules as `--index-url`",
)
parser.add_argument(
"--skip-editable",
action="store_true",
help="don't audit packages that are marked as editable",
)
parser.add_argument(
"--no-deps",
action="store_true",
help="don't perform any dependency resolution; requires all requirements are pinned "
"to an exact version",
)
parser.add_argument(
"-o",
"--output",
type=Path,
metavar="FILE",
help="output results to the given file",
default=os.environ.get("PIP_AUDIT_OUTPUT", "stdout"),
)
parser.add_argument(
"--ignore-vuln",
type=str,
metavar="ID",
action="append",
dest="ignore_vulns",
default=[],
help=(
"ignore a specific vulnerability by its vulnerability ID; "
"this option can be used multiple times"
),
)
parser.add_argument(
"--disable-pip",
action="store_true",
help="don't use `pip` for dependency resolution; "
"this can only be used with hashed requirements files or if the `--no-deps` flag has been "
"provided",
)
return parser
def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace: # pragma: no cover
args = parser.parse_args()
# Configure logging upfront, so that we don't miss anything.
if args.verbose >= 1:
package_logger.setLevel("DEBUG")
if args.verbose >= 2:
logging.getLogger().setLevel("DEBUG")
logger.debug(f"parsed arguments: {args}")
return args
def _dep_source_from_project_path(
project_path: Path, index_url: str, extra_index_urls: list[str], locked: bool, state: AuditState
) -> DependencySource: # pragma: no cover
# If the user has passed `--locked`, we check for `pylock.*.toml` files.
if locked:
all_pylocks = list(project_path.glob("pylock.*.toml"))
generic_pylock = project_path / "pylock.toml"
if generic_pylock.is_file():
all_pylocks.append(generic_pylock)
if not all_pylocks:
_fatal(f"no lockfiles found in {project_path}")
return PyLockSource(all_pylocks)
# Check for a `pyproject.toml`
pyproject_path = project_path / "pyproject.toml"
if pyproject_path.is_file():
return PyProjectSource(
pyproject_path,
index_url=index_url,
extra_index_urls=extra_index_urls,
state=state,
)
# TODO: Checks for setup.py and other project files will go here.
_fatal(f"couldn't find a supported project file in {project_path}")
def audit() -> None: # pragma: no cover
"""
The primary entrypoint for `pip-audit`.
"""
parser = _parser()
args = _parse_args(parser)
service = args.vulnerability_service.to_service(args.timeout, args.cache_dir)
output_desc = args.desc.to_bool(args.format)
output_aliases = args.aliases.to_bool(args.format)
formatter = args.format.to_format(output_desc, output_aliases)
# Check for flags that are only valid with project paths
if args.project_path is None:
if args.locked:
parser.error("The --locked flag can only be used with a project path")
# Check for flags that are only valid with requirements files
if args.requirements is None:
if args.require_hashes:
parser.error("The --require-hashes flag can only be used with --requirement (-r)")
elif args.index_url:
parser.error("The --index-url flag can only be used with --requirement (-r)")
elif args.extra_index_urls:
parser.error("The --extra-index-url flag can only be used with --requirement (-r)")
elif args.no_deps:
parser.error("The --no-deps flag can only be used with --requirement (-r)")
elif args.disable_pip:
parser.error("The --disable-pip flag can only be used with --requirement (-r)")
# Nudge users to consider alternate workflows.
if args.require_hashes and args.no_deps:
logger.warning("The --no-deps flag is redundant when used with --require-hashes")
if args.require_hashes and isinstance(service, OsvService):
logger.warning(
"The --require-hashes flag with --service osv only enforces hash presence NOT hash "
"validity. Use --service pypi to enforce hash validity."
)
if args.no_deps:
logger.warning(
"--no-deps is supported, but users are encouraged to fully hash their "
"pinned dependencies"
)
logger.warning(
"Consider using a tool like `pip-compile`: "
"https://pip-tools.readthedocs.io/en/latest/#using-hashes"
)
with ExitStack() as stack:
actors = []
if args.progress_spinner:
actors.append(AuditSpinner("Collecting inputs"))
state = stack.enter_context(AuditState(members=actors))
source: DependencySource
if args.requirements is not None:
for req in args.requirements:
if not req.exists():
_fatal(f"invalid requirements input: {req}")
source = RequirementSource(
args.requirements,
require_hashes=args.require_hashes,
no_deps=args.no_deps,
disable_pip=args.disable_pip,
skip_editable=args.skip_editable,
index_url=args.index_url,
extra_index_urls=args.extra_index_urls,
state=state,
)
elif args.project_path is not None:
# NOTE: We'll probably want to support --skip-editable here,
# once PEP 660 is more widely supported: https://www.python.org/dev/peps/pep-0660/
# Determine which kind of project file exists in the project path
source = _dep_source_from_project_path(
args.project_path,
args.index_url,
args.extra_index_urls,
args.locked,
state,
)
else:
source = PipSource(
local=args.local,
paths=args.paths,
skip_editable=args.skip_editable,
state=state,
)
# `--dry-run` only affects the auditor if `--fix` is also not supplied,
# since the combination of `--dry-run` and `--fix` implies that the user
# wants to dry-run the "fix" step instead of the "audit" step
auditor = Auditor(service, options=AuditOptions(dry_run=args.dry_run and not args.fix))
result = {}
pkg_count = 0
vuln_count = 0
skip_count = 0
vuln_ignore_count = 0
vulns_to_ignore = set(args.ignore_vulns)
try:
for spec, vulns in auditor.audit(source):
if spec.is_skipped():
spec = cast(SkippedDependency, spec)
if args.strict:
_fatal(f"{spec.name}: {spec.skip_reason}")
else:
state.update_state(f"Skipping {spec.name}: {spec.skip_reason}")
skip_count += 1
else:
spec = cast(ResolvedDependency, spec)
logger.debug(f"Auditing {spec.name} ({spec.version})")
state.update_state(f"Auditing {spec.name} ({spec.version})")
if vulns_to_ignore:
filtered_vulns = [v for v in vulns if not v.has_any_id(vulns_to_ignore)]
vuln_ignore_count += len(vulns) - len(filtered_vulns)
vulns = filtered_vulns
result[spec] = vulns
if len(vulns) > 0:
pkg_count += 1
vuln_count += len(vulns)
except DependencySourceError as e:
_fatal(str(e))
except VulnServiceConnectionError as e:
# The most common source of connection errors is corporate blocking,
# so we offer a bit of advice.
logger.error(str(e))
_fatal(
"Tip: your network may be blocking this service. "
"Try another service with `-s SERVICE`"
)
# If the `--fix` flag has been applied, find a set of suitable fix versions and upgrade the
# dependencies at the source
fixes = list()
fixed_pkg_count = 0
fixed_vuln_count = 0
if args.fix:
for fix in resolve_fix_versions(service, result, state):
if args.dry_run:
if fix.is_skipped():
fix = cast(SkippedFixVersion, fix)
logger.info(
f"Dry run: would have skipped {fix.dep.name} "
f"upgrade because {fix.skip_reason}"
)
else:
fix = cast(ResolvedFixVersion, fix)
logger.info(f"Dry run: would have upgraded {fix.dep.name} to {fix.version}")
continue
if not fix.is_skipped():
fix = cast(ResolvedFixVersion, fix)
try:
source.fix(fix)
fixed_pkg_count += 1
fixed_vuln_count += len(result[fix.dep])
except DependencySourceError as dse:
skip_reason = str(dse)
logger.debug(skip_reason)
fix = SkippedFixVersion(fix.dep, skip_reason)
fixes.append(fix)
if vuln_count > 0:
if vuln_ignore_count:
ignored = f", ignored {vuln_ignore_count}"
else:
ignored = ""
summary_msg = (
f"Found {vuln_count} known "
f"{'vulnerability' if vuln_count == 1 else 'vulnerabilities'}"
f"{ignored} in {pkg_count} {'package' if pkg_count == 1 else 'packages'}"
)
if args.fix:
summary_msg += (
f" and fixed {fixed_vuln_count} "
f"{'vulnerability' if fixed_vuln_count == 1 else 'vulnerabilities'} "
f"in {fixed_pkg_count} "
f"{'package' if fixed_pkg_count == 1 else 'packages'}"
)
print(summary_msg, file=sys.stderr)
with _output_io(args.output) as io:
print(formatter.format(result, fixes), file=io)
if pkg_count != fixed_pkg_count:
sys.exit(1)
else:
summary_msg = "No known vulnerabilities found"
if vuln_ignore_count:
summary_msg += f", {vuln_ignore_count} ignored"
print(
summary_msg,
file=sys.stderr,
)
# If our output format is a "manifest" format we always emit it,
# even if nothing other than a dependency summary is present.
if skip_count > 0 or formatter.is_manifest:
with _output_io(args.output) as io:
print(formatter.format(result, fixes), file=io)
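A small sketch of the deferred-output helper described above (not part of the commit; `report.txt` is an illustrative path). Unlike `argparse.FileType`, `_output_io` only creates the file once the audit actually produces output:

# Hedged sketch: "-" and "stdout" map to sys.stdout; anything else is opened
# lazily, on entry to the context manager.
from pathlib import Path
from pip_audit._cli import _output_io

with _output_io(Path("-")) as io:
    print("written to stdout", file=io)
with _output_io(Path("report.txt")) as io:
    print("written to report.txt", file=io)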

View File

@@ -0,0 +1,28 @@
"""
Dependency source interfaces and implementations for `pip-audit`.
"""
from .interface import (
PYPI_URL,
DependencyFixError,
DependencySource,
DependencySourceError,
InvalidRequirementSpecifier,
)
from .pip import PipSource, PipSourceError
from .pylock import PyLockSource
from .pyproject import PyProjectSource
from .requirement import RequirementSource
__all__ = [
"PYPI_URL",
"DependencyFixError",
"DependencySource",
"DependencySourceError",
"InvalidRequirementSpecifier",
"PipSource",
"PipSourceError",
"PyLockSource",
"PyProjectSource",
"RequirementSource",
]

View File

@@ -0,0 +1,69 @@
"""
Interfaces for interacting with "dependency sources", i.e. sources
of fully resolved Python dependency trees.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterator
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency
PYPI_URL = "https://pypi.org/simple/"
class DependencySource(ABC):
"""
Represents an abstract source of fully-resolved Python dependencies.
Individual concrete dependency sources (e.g. `pip list`) are expected
to subclass `DependencySource` and implement it in their terms.
"""
@abstractmethod
def collect(self) -> Iterator[Dependency]: # pragma: no cover
"""
Yield the dependencies in this source.
"""
raise NotImplementedError
@abstractmethod
def fix(self, fix_version: ResolvedFixVersion) -> None: # pragma: no cover
"""
Upgrade a dependency to the given fix version.
"""
raise NotImplementedError
class DependencySourceError(Exception):
"""
Raised when a `DependencySource` fails to provide its dependencies.
Concrete implementations are expected to subclass this exception to
provide more context.
"""
pass
class DependencyFixError(Exception):
"""
Raised when a `DependencySource` fails to perform a "fix" operation, i.e.
fails to upgrade a package to a different version.
Concrete implementations are expected to subclass this exception to provide
more context.
"""
pass
class InvalidRequirementSpecifier(DependencySourceError):
"""
A `DependencySourceError` specialized for the case of a non-PEP 440 requirements
specifier.
"""
pass
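A minimal sketch of a concrete source written against the interface above (hypothetical `StaticSource` class, not part of the commit):

# Hedged sketch: an in-memory dependency source with a hard-coded package list.
from collections.abc import Iterator

from packaging.version import Version

from pip_audit._dependency_source import DependencySource
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency, ResolvedDependency

class StaticSource(DependencySource):
    """Yields a fixed dependency list; fix operations are unsupported."""

    def collect(self) -> Iterator[Dependency]:
        yield ResolvedDependency("requests", Version("2.31.0"))

    def fix(self, fix_version: ResolvedFixVersion) -> None:
        raise NotImplementedError("static sources cannot be fixed")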

View File

@@ -0,0 +1,175 @@
"""
Collect the local environment's active dependencies via `pip list`, wrapped
by `pip-api`.
"""
import logging
import os
import subprocess
import sys
from collections.abc import Iterator, Sequence
from pathlib import Path
import pip_api
from packaging.version import InvalidVersion, Version
from pip_audit._dependency_source import (
DependencyFixError,
DependencySource,
DependencySourceError,
)
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency, ResolvedDependency, SkippedDependency
from pip_audit._state import AuditState
logger = logging.getLogger(__name__)
# Versions of `pip` prior to this version don't support `pip list -v --format=json`,
# which is our baseline for reliable output. We'll attempt to use versions before
# this one, but not before complaining about it.
_MINIMUM_RELIABLE_PIP_VERSION = Version("10.0.0b0")
# NOTE(ww): The round-trip assignment here is due to type confusion: `pip_api.PIP_VERSION`
# is a `Version` object, but it's a `pip_api._vendor.packaging.version.Version` instead
# of a `packaging.version.Version`. Recreating the version with the correct type
# ensures that our comparison operators work as expected.
_PIP_VERSION = Version(str(pip_api.PIP_VERSION))
class PipSource(DependencySource):
"""
Wraps `pip` (specifically `pip list`) as a dependency source.
"""
def __init__(
self,
*,
local: bool = False,
paths: Sequence[Path] = [],
skip_editable: bool = False,
state: AuditState = AuditState(),
) -> None:
"""
Create a new `PipSource`.
`local` determines whether to do a "local-only" list. If `True`, the
`DependencySource` does not expose globally installed packages.
`paths` is a list of locations to look for installed packages. If the
list is empty, the `DependencySource` will query the current Python
environment.
`skip_editable` controls whether dependencies marked as "editable" are skipped.
By default, editable dependencies are not skipped.
`state` is an `AuditState` to use for state callbacks.
"""
self._local = local
self._paths = paths
self._skip_editable = skip_editable
self.state = state
# NOTE: By default `pip_api` invokes `pip` through `sys.executable`, like so:
#
# {sys.executable} -m pip [args ...]
#
# This is the right decision 99% of the time, but it can result in unintuitive audits
# for users who have installed `pip-audit` globally but are trying to audit
# a loaded virtual environment, since `pip-audit`'s `sys.executable` will be the global
# Python and not the virtual environment's Python.
#
# To check for this, we check whether the Python that `pip_api` plans to use
# matches the active virtual environment's prefix. We do this instead of comparing
# against the $PATH-prioritized Python because that might be the same "effective"
# Python but with a different symlink (e.g. `<path>/python{,3,3.7}`). We *could*
        # handle that case by resolving the symlinks, but that would then pierce the
# virtual environment that we're attempting to detect.
effective_python = os.environ.get("PIPAPI_PYTHON_LOCATION", sys.executable)
venv_prefix = os.getenv("VIRTUAL_ENV")
if venv_prefix is not None and not effective_python.startswith(venv_prefix):
logger.warning(
f"pip-audit will run pip against {effective_python}, but you have "
f"a virtual environment loaded at {venv_prefix}. This may result in "
"unintuitive audits, since your local environment will not be audited. "
"You can forcefully override this behavior by setting PIPAPI_PYTHON_LOCATION "
"to the location of your virtual environment's Python interpreter."
)
if _PIP_VERSION < _MINIMUM_RELIABLE_PIP_VERSION:
logger.warning(
f"pip {_PIP_VERSION} is very old, and may not provide reliable "
"dependency information! You are STRONGLY encouraged to upgrade to a "
"newer version of pip."
)
def collect(self) -> Iterator[Dependency]:
"""
Collect all of the dependencies discovered by this `PipSource`.
Raises a `PipSourceError` on any errors.
"""
# The `pip list` call that underlies `pip_api` could fail for myriad reasons.
# We collect them all into a single well-defined error.
try:
for _, dist in pip_api.installed_distributions(
local=self._local, paths=list(self._paths)
).items():
dep: Dependency
if dist.editable and self._skip_editable:
dep = SkippedDependency(
name=dist.name, skip_reason="distribution marked as editable"
)
else:
try:
dep = ResolvedDependency(name=dist.name, version=Version(str(dist.version)))
self.state.update_state(f"Collecting {dep.name} ({dep.version})")
except InvalidVersion:
skip_reason = (
"Package has invalid version and could not be audited: "
f"{dist.name} ({dist.version})"
)
logger.debug(skip_reason)
dep = SkippedDependency(name=dist.name, skip_reason=skip_reason)
yield dep
except Exception as e:
raise PipSourceError("failed to list installed distributions") from e
def fix(self, fix_version: ResolvedFixVersion) -> None:
"""
Fixes a dependency version in this `PipSource`.
"""
self.state.update_state(
f"Fixing {fix_version.dep.name} ({fix_version.dep.version} => {fix_version.version})"
)
fix_cmd = [
sys.executable,
"-m",
"pip",
"install",
f"{fix_version.dep.canonical_name}=={fix_version.version}",
]
try:
subprocess.run(
fix_cmd,
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
except subprocess.CalledProcessError as cpe:
raise PipFixError(
f"failed to upgrade dependency {fix_version.dep.name} to fix version "
f"{fix_version.version}"
) from cpe
class PipSourceError(DependencySourceError):
"""A `pip` specific `DependencySourceError`."""
pass
class PipFixError(DependencyFixError):
"""A `pip` specific `DependencyFixError`."""
pass
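A usage sketch for the source above (not part of the commit; assumes an installed local environment):

# Hedged sketch: enumerate the local environment, skipping editable installs.
from pip_audit._dependency_source import PipSource
from pip_audit._service import SkippedDependency

for dep in PipSource(local=True, skip_editable=True).collect():
    if isinstance(dep, SkippedDependency):
        print(f"skipped {dep.name}: {dep.skip_reason}")
    else:
        print(f"{dep.name}=={dep.version}")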

View File

@@ -0,0 +1,112 @@
"""
Collect dependencies from `pylock.toml` files.
"""
import logging
from collections.abc import Iterator
from pathlib import Path
import toml
from packaging.version import Version
from pip_audit._dependency_source import DependencyFixError, DependencySource, DependencySourceError
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency, ResolvedDependency
from pip_audit._service.interface import SkippedDependency
logger = logging.getLogger(__name__)
class PyLockSource(DependencySource):
"""
Wraps `pylock.*.toml` dependency collection as a dependency source.
"""
def __init__(self, filenames: list[Path]) -> None:
"""
Create a new `PyLockSource`.
`filenames` provides a list of `pylock.*.toml` files to parse.
"""
self._filenames = filenames
def collect(self) -> Iterator[Dependency]:
"""
Collect all of the dependencies discovered by this `PyLockSource`.
Raises a `PyLockSourceError` on any errors.
"""
for filename in self._filenames:
yield from self._collect_from_file(filename)
def _collect_from_file(self, filename: Path) -> Iterator[Dependency]:
"""
Collect dependencies from a single `pylock.*.toml` file.
Raises a `PyLockSourceError` on any errors.
"""
try:
pylock = toml.load(filename)
except toml.TomlDecodeError as e:
raise PyLockSourceError(f"{filename}: invalid TOML in lockfile") from e
lock_version = pylock.get("lock-version")
if not lock_version:
raise PyLockSourceError(f"{filename}: missing lock-version in lockfile")
lock_version = Version(lock_version)
if lock_version.major != 1:
raise PyLockSourceError(f"{filename}: lockfile version {lock_version} is not supported")
packages = pylock.get("packages")
if not packages:
raise PyLockSourceError(f"{filename}: missing packages in lockfile")
try:
yield from self._collect_from_packages(packages)
except PyLockSourceError as e:
raise PyLockSourceError(f"{filename}: {e}") from e
def _collect_from_packages(self, packages: list[dict]) -> Iterator[Dependency]:
"""
Collect dependencies from a list of packages.
Raises a `PyLockSourceError` on any errors.
"""
for idx, package in enumerate(packages):
name = package.get("name")
if not name:
raise PyLockSourceError(f"invalid package #{idx}: no name")
version = package.get("version")
if version:
yield ResolvedDependency(name, Version(version))
else:
# Versions are optional in PEP 751, e.g. for source tree specifiers.
# We mark these as skipped.
yield SkippedDependency(name, "no version specified")
def fix(self, fix_version: ResolvedFixVersion) -> None: # pragma: no cover
"""
Raises `NotImplementedError` if called.
We don't support fixing dependencies in lockfiles, since
lockfiles should be managed/updated by their packaging tool.
"""
raise NotImplementedError(
"lockfiles cannot be fixed directly; use your packaging tool to perform upgrades"
)
class PyLockSourceError(DependencySourceError):
"""A pylock-parsing specific `DependencySourceError`."""
pass
class PyLockFixError(DependencyFixError):
"""A pylock-fizing specific `DependencyFixError`."""
pass
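A sketch of the smallest lockfile the parser above accepts (not part of the commit; the package name, version, and path are illustrative):

# Hedged sketch: a `lock-version = "1.x"` header plus a non-empty `packages`
# array is enough to satisfy the validation in `_collect_from_file`.
from pathlib import Path
from pip_audit._dependency_source.pylock import PyLockSource

lock = Path("pylock.toml")
lock.write_text('lock-version = "1.0"\n\n[[packages]]\nname = "idna"\nversion = "3.7"\n')
for dep in PyLockSource([lock]).collect():
    print(dep.name)  # "idna", resolved at version 3.7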

View File

@@ -0,0 +1,159 @@
"""
Collect dependencies from `pyproject.toml` files.
"""
from __future__ import annotations
import logging
import os
from collections.abc import Iterator
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
import toml
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from pip_audit._dependency_source import (
DependencyFixError,
DependencySource,
DependencySourceError,
)
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency, ResolvedDependency
from pip_audit._state import AuditState
from pip_audit._virtual_env import VirtualEnv, VirtualEnvError
logger = logging.getLogger(__name__)
class PyProjectSource(DependencySource):
"""
Wraps `pyproject.toml` dependency resolution as a dependency source.
"""
def __init__(
self,
filename: Path,
index_url: str | None = None,
extra_index_urls: list[str] = [],
state: AuditState = AuditState(),
) -> None:
"""
Create a new `PyProjectSource`.
        `filename` provides a path to a `pyproject.toml` file.
`index_url` is the base URL of the package index.
`extra_index_urls` are the extra URLs of package indexes.
`state` is an `AuditState` to use for state callbacks.
"""
        self.filename = filename
        self.index_url = index_url
        self.extra_index_urls = extra_index_urls
        self.state = state
def collect(self) -> Iterator[Dependency]:
"""
Collect all of the dependencies discovered by this `PyProjectSource`.
Raises a `PyProjectSourceError` on any errors.
"""
with self.filename.open("r") as f:
pyproject_data = toml.load(f)
project = pyproject_data.get("project")
if project is None:
raise PyProjectSourceError(
f"pyproject file {self.filename} does not contain `project` section"
)
deps = project.get("dependencies")
if deps is None:
# Projects without dependencies aren't an error case
logger.warning(
f"pyproject file {self.filename} does not contain `dependencies` list"
)
return
# NOTE(alex): This is probably due for a redesign. Since we're leaning on `pip` for
# dependency resolution now, we can think about doing `pip install <local-project-dir>`
# regardless of whether the project has a `pyproject.toml` or not. And if it doesn't
# have a `pyproject.toml`, we can raise an error if the user provides `--fix`.
with (
TemporaryDirectory() as ve_dir,
NamedTemporaryFile(dir=ve_dir, delete=False) as req_file,
):
# We use delete=False in creating the tempfile to allow it to be
# closed and opened multiple times within the context scope on
# windows, see GitHub issue #646.
# Write the dependencies to a temporary requirements file.
req_file.write(os.linesep.join(deps).encode())
req_file.flush()
# Try to install the generated requirements file.
            # Forward the configured index URLs so resolution uses them.
            ve = VirtualEnv(["-r", req_file.name], self.index_url, self.extra_index_urls, self.state)
try:
ve.create(ve_dir)
except VirtualEnvError as exc:
raise PyProjectSourceError(str(exc)) from exc
# Now query the installed packages.
for name, version in ve.installed_packages:
yield ResolvedDependency(name=name, version=version)
def fix(self, fix_version: ResolvedFixVersion) -> None:
"""
Fixes a dependency version for this `PyProjectSource`.
"""
with self.filename.open("r+") as f, NamedTemporaryFile(mode="r+", delete=False) as tmp:
pyproject_data = toml.load(f)
project = pyproject_data.get("project")
if project is None:
raise PyProjectFixError(
f"pyproject file {self.filename} does not contain `project` section"
)
deps = project.get("dependencies")
if deps is None:
# Projects without dependencies aren't an error case
logger.warning(
f"pyproject file {self.filename} does not contain `dependencies` list"
)
return
reqs = [Requirement(dep) for dep in deps]
for i in range(len(reqs)):
# When we find a requirement that matches the provided fix version, we need to edit
# the requirement's specifier and then write it back to the underlying TOML data.
req = reqs[i]
if (
req.name == fix_version.dep.name
and req.specifier.contains(fix_version.dep.version)
and not req.specifier.contains(fix_version.version)
):
req.specifier = SpecifierSet(f"=={fix_version.version}")
deps[i] = str(req)
assert req.marker is None or req.marker.evaluate()
# Now dump the new edited TOML to the temporary file.
toml.dump(pyproject_data, tmp)
# And replace the original `pyproject.toml` file.
os.replace(tmp.name, self.filename)
class PyProjectSourceError(DependencySourceError):
"""A `pyproject.toml` specific `DependencySourceError`."""
pass
class PyProjectFixError(DependencyFixError):
"""A `pyproject.toml` specific `DependencyFixError`."""
pass
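A sketch of the fix path above (not part of the commit; the package, versions, and path are illustrative):

# Hedged sketch: rewrite any `idna` specifier in `pyproject.toml` that admits
# the vulnerable 3.6 release but not the fixed 3.7 release to `idna==3.7`.
from pathlib import Path

from packaging.version import Version

from pip_audit._dependency_source import PyProjectSource
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import ResolvedDependency

source = PyProjectSource(Path("pyproject.toml"))
fix = ResolvedFixVersion(dep=ResolvedDependency("idna", Version("3.6")), version=Version("3.7"))
source.fix(fix)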

View File

@@ -0,0 +1,371 @@
"""
Collect dependencies from one or more `requirements.txt`-formatted files.
"""
from __future__ import annotations
import logging
import re
import shutil
from collections.abc import Iterator
from contextlib import ExitStack
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import IO
from packaging.specifiers import SpecifierSet
from packaging.utils import canonicalize_name
from packaging.version import Version
from pip_requirements_parser import (
InstallRequirement,
InvalidRequirementLine,
RequirementsFile,
)
from pip_audit._dependency_source import (
DependencyFixError,
DependencySource,
DependencySourceError,
InvalidRequirementSpecifier,
)
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency
from pip_audit._service.interface import ResolvedDependency, SkippedDependency
from pip_audit._state import AuditState
from pip_audit._virtual_env import VirtualEnv, VirtualEnvError
logger = logging.getLogger(__name__)
PINNED_SPECIFIER_RE = re.compile(r"==(?P<version>.+?)$", re.VERBOSE)
class RequirementSource(DependencySource):
"""
Wraps `requirements.txt` dependency resolution as a dependency source.
"""
def __init__(
self,
filenames: list[Path],
*,
require_hashes: bool = False,
no_deps: bool = False,
disable_pip: bool = False,
skip_editable: bool = False,
index_url: str | None = None,
extra_index_urls: list[str] = [],
state: AuditState = AuditState(),
) -> None:
"""
Create a new `RequirementSource`.
`filenames` provides the list of filepaths to parse.
`require_hashes` controls the hash policy: if `True`, dependency collection
will fail unless all requirements include hashes.
`disable_pip` controls the dependency resolution policy: if `True`,
dependency resolution is not performed and the inputs are checked
and treated as "frozen".
`no_deps` controls whether dependency resolution can be disabled even without
hashed requirements (which implies a fully resolved requirements file): if `True`,
`disable_pip` is allowed without a hashed requirements file.
`skip_editable` controls whether requirements marked as "editable" are skipped.
By default, editable requirements are not skipped.
`index_url` is the base URL of the package index.
`extra_index_urls` are the extra URLs of package indexes.
`state` is an `AuditState` to use for state callbacks.
"""
self._filenames = filenames
self._require_hashes = require_hashes
self._no_deps = no_deps
self._disable_pip = disable_pip
self._skip_editable = skip_editable
self._index_url = index_url
self._extra_index_urls = extra_index_urls
self.state = state
self._dep_cache: dict[Path, set[Dependency]] = {}
def collect(self) -> Iterator[Dependency]:
"""
Collect all of the dependencies discovered by this `RequirementSource`.
Raises a `RequirementSourceError` on any errors.
"""
collect_files = []
tmp_files = []
try:
for filename in self._filenames:
# We need to handle process substitution inputs so we can invoke
# `pip-audit` like so:
#
# pip-audit -r <(echo 'something')
#
# Since `/dev/fd/<n>` inputs are unique to the parent process,
# we can't pass these file names to `pip` and expect `pip` to
                # be able to read them.
#
# In order to get around this, we're going to copy each input
# into a corresponding temporary file and then pass that set of
# files into `pip`.
if filename.is_fifo():
# Deliberately pass `delete=False` so that our temporary
# file doesn't get automatically deleted on close. We need
                    # to close it so that `pip` can use it; however, we
# obviously want it to persist.
tmp_file = NamedTemporaryFile(mode="w", delete=False)
with filename.open("r") as f:
shutil.copyfileobj(f, tmp_file)
# Close the file since it's going to get re-opened by `pip`.
tmp_file.close()
filename = Path(tmp_file.name)
tmp_files.append(filename)
collect_files.append(filename)
# Now pass the list of filenames into the rest of our logic.
yield from self._collect_from_files(collect_files)
finally:
            # Since we disabled automatic deletion for these temporary
# files, we need to manually delete them on the way out.
for t in tmp_files:
t.unlink()
def _collect_from_files(self, filenames: list[Path]) -> Iterator[Dependency]:
# Figure out whether we have a fully resolved set of dependencies.
reqs: list[InstallRequirement] = []
require_hashes: bool = self._require_hashes
for filename in filenames:
rf = RequirementsFile.from_file(filename)
if len(rf.invalid_lines) > 0:
invalid = rf.invalid_lines[0]
raise InvalidRequirementSpecifier(
f"requirement file {filename} contains invalid specifier at "
f"line {invalid.line_number}: {invalid.error_message}"
)
# If one or more requirements have a hash, this implies `--require-hashes`.
require_hashes = require_hashes or any(req.hash_options for req in rf.requirements)
reqs.extend(rf.requirements)
        # If the user has supplied `--no-deps` or there are hashed requirements, we can assume
        # that we have a fully resolved set of dependencies and shouldn't waste time by invoking
        # `pip`.
if self._disable_pip:
if not self._no_deps and not require_hashes:
raise RequirementSourceError(
"the --disable-pip flag can only be used with a hashed requirements files or "
"if the --no-deps flag has been provided"
)
yield from self._collect_preresolved_deps(iter(reqs), require_hashes)
return
ve_args = []
if self._require_hashes:
ve_args.append("--require-hashes")
for filename in filenames:
ve_args.extend(["-r", str(filename)])
# Try to install the supplied requirements files.
ve = VirtualEnv(ve_args, self._index_url, self._extra_index_urls, self.state)
try:
with TemporaryDirectory() as ve_dir:
ve.create(ve_dir)
except VirtualEnvError as exc:
raise RequirementSourceError(str(exc)) from exc
# Now query the installed packages.
for name, version in ve.installed_packages:
yield ResolvedDependency(name=name, version=version)
def fix(self, fix_version: ResolvedFixVersion) -> None:
"""
Fixes a dependency version for this `RequirementSource`.
"""
with ExitStack() as stack:
# Make temporary copies of the existing requirements files. If anything goes wrong, we
# want to copy them back into place and undo any partial application of the fix.
tmp_files: list[IO[str]] = [
stack.enter_context(NamedTemporaryFile(mode="r+")) for _ in self._filenames
]
for filename, tmp_file in zip(self._filenames, tmp_files):
with filename.open("r") as f:
shutil.copyfileobj(f, tmp_file)
try:
                # Now fix the files in place.
for filename in self._filenames:
self.state.update_state(
f"Fixing dependency {fix_version.dep.name} ({fix_version.dep.version} => "
f"{fix_version.version})"
)
self._fix_file(filename, fix_version)
except Exception as e:
logger.warning(
f"encountered an exception while applying fixes, recovering original files: {e}"
)
self._recover_files(tmp_files)
raise e
def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None:
# Reparse the requirements file. We want to rewrite each line to the new requirements file
# and only modify the lines that we're fixing.
#
        # This time we're using the `RequirementsFile.parse` API instead of `RequirementsFile.from_file`
# since we want to access each line sequentially in order to rewrite the file.
reqs = list(RequirementsFile.parse(filename=filename.as_posix()))
# Check ahead of time for anything invalid in the requirements file since we don't want to
# encounter this while writing out the file. Check for duplicate requirements and lines that
# failed to parse.
req_specifiers: dict[str, SpecifierSet] = dict()
for req in reqs:
if (
isinstance(req, InstallRequirement)
and (req.marker is None or req.marker.evaluate())
and req.req is not None
):
duplicate_req_specifier = req_specifiers.get(req.name)
if not duplicate_req_specifier:
req_specifiers[req.name] = req.specifier
elif duplicate_req_specifier != req.specifier:
raise RequirementFixError(
f"package {req.name} has duplicate requirements: {str(req)}"
)
elif isinstance(req, InvalidRequirementLine):
raise RequirementFixError(
f"requirement file {filename} has invalid requirement: {str(req)}"
)
# Now write out the new requirements file
with filename.open("w") as f:
found = False
for req in reqs:
if (
isinstance(req, InstallRequirement)
and canonicalize_name(req.name) == fix_version.dep.canonical_name
):
found = True
if req.specifier.contains(
fix_version.dep.version
) and not req.specifier.contains(fix_version.version):
req.req.specifier = SpecifierSet(f"=={fix_version.version}")
print(req.dumps(), file=f)
# The vulnerable dependency may not be explicitly listed in the requirements file if it
# is a subdependency of a requirement. In this case, we should explicitly add the fixed
# dependency into the requirements file.
#
# To know whether this is the case, we'll need to resolve dependencies if we haven't
# already in order to figure out whether this subdependency belongs to this file or
# another.
if not found:
logger.warning(
"added fixed subdependency explicitly to requirements file "
f"{filename}: {fix_version.dep.canonical_name}"
)
print(
" # pip-audit: subdependency explicitly fixed",
file=f,
)
print(f"{fix_version.dep.canonical_name}=={fix_version.version}", file=f)
def _recover_files(self, tmp_files: list[IO[str]]) -> None:
for filename, tmp_file in zip(self._filenames, tmp_files):
try:
tmp_file.seek(0)
with filename.open("w") as f:
shutil.copyfileobj(tmp_file, f)
except Exception as e:
# Not much we can do at this point since we're already handling an exception. Just
# log the error and try to recover the rest of the files.
logger.warning(f"encountered an exception during file recovery: {e}")
continue
def _collect_preresolved_deps(
self, reqs: Iterator[InstallRequirement], require_hashes: bool
) -> Iterator[Dependency]:
"""
Collect pre-resolved (pinned) dependencies.
"""
req_specifiers: dict[str, SpecifierSet] = dict()
for req in reqs:
if not req.hash_options and require_hashes:
raise RequirementSourceError(f"requirement {req.dumps()} does not contain a hash")
if req.req is None:
# PEP 508-style URL requirements don't have a pre-declared version, even
# when hashed; the `#egg=name==version` syntax is non-standard and not supported
# by `pip` itself.
#
# In this case, we can't audit the dependency so we should signal to the
# caller that we're skipping it.
yield SkippedDependency(
name=req.requirement_line.line,
skip_reason="could not deduce package version from URL requirement",
)
continue
            if self._skip_editable and req.is_editable:
                yield SkippedDependency(name=req.name, skip_reason="requirement marked as editable")
                # Don't process this requirement any further.
                continue
if req.marker is not None and not req.marker.evaluate():
# TODO(ww): Remove this `no cover` pragma once we're 3.10+.
# See: https://github.com/nedbat/coveragepy/issues/198
continue # pragma: no cover
duplicate_req_specifier = req_specifiers.get(req.name)
if not duplicate_req_specifier:
req_specifiers[req.name] = req.specifier
# We have a duplicate requirement for the same package
# but different specifiers, meaning a badly resolved requirements.txt
elif duplicate_req_specifier != req.specifier:
raise RequirementSourceError(
f"package {req.name} has duplicate requirements: {str(req)}"
)
else:
                # We have a duplicate requirement for the same package and the specifier
                # matches. Since it would produce the same audit result, there's no need
                # to yield it a second time.
continue # pragma: no cover
# NOTE: URL dependencies cannot be pinned, so skipping them
# makes sense (under the same principle of skipping dependencies
# that can't be found on PyPI). This is also consistent with
# what `pip --no-deps` does (installs the URL dependency, but
# not any subdependencies).
if req.is_url:
yield SkippedDependency(
name=req.name,
skip_reason="URL requirements cannot be pinned to a specific package version",
)
elif not req.specifier:
raise RequirementSourceError(f"requirement {req.name} is not pinned: {str(req)}")
else:
pinned_specifier = PINNED_SPECIFIER_RE.match(str(req.specifier))
if pinned_specifier is None:
raise RequirementSourceError(
f"requirement {req.name} is not pinned to an exact version: {str(req)}"
)
yield ResolvedDependency(req.name, Version(pinned_specifier.group("version")))
class RequirementSourceError(DependencySourceError):
"""A requirements-parsing specific `DependencySourceError`."""
pass
class RequirementFixError(DependencyFixError):
"""A requirements-fixing specific `DependencyFixError`."""
pass
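A sketch of the pinning check used by `_collect_preresolved_deps` above (not part of the commit; the specifiers are illustrative):

# Hedged sketch: only exact `==` pins pass the pre-resolved path; range and
# compatible-release specifiers are rejected as "not pinned".
import re

PINNED_SPECIFIER_RE = re.compile(r"==(?P<version>.+?)$", re.VERBOSE)

for spec in ["==2.31.0", ">=2.0", "~=2.31"]:
    match = PINNED_SPECIFIER_RE.match(spec)
    print(spec, "->", match.group("version") if match else "not pinned")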

View File

@@ -0,0 +1,126 @@
"""
Functionality for resolving fixed versions of dependencies.
"""
from __future__ import annotations
import logging
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, cast
from packaging.version import Version
from pip_audit._service import (
Dependency,
ResolvedDependency,
VulnerabilityResult,
VulnerabilityService,
)
from pip_audit._state import AuditState
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class FixVersion:
"""
Represents an abstract dependency fix version.
This class cannot be constructed directly.
"""
dep: ResolvedDependency
def __init__(self, *_args: Any, **_kwargs: Any) -> None: # pragma: no cover
"""
A stub constructor that always fails.
"""
raise NotImplementedError
def is_skipped(self) -> bool:
"""
Check whether the `FixVersion` was unable to be resolved.
"""
return self.__class__ is SkippedFixVersion
@dataclass(frozen=True)
class ResolvedFixVersion(FixVersion):
"""
Represents a resolved fix version.
"""
version: Version
@dataclass(frozen=True)
class SkippedFixVersion(FixVersion):
"""
    Represents a fix version that could not be resolved and was therefore skipped.
"""
skip_reason: str
def resolve_fix_versions(
service: VulnerabilityService,
result: dict[Dependency, list[VulnerabilityResult]],
state: AuditState = AuditState(),
) -> Iterator[FixVersion]:
"""
Resolves a mapping of dependencies to known vulnerabilities to a series of fix versions without
known vulnerabilities.
"""
for dep, vulns in result.items():
if dep.is_skipped():
continue
if not vulns:
continue
dep = cast(ResolvedDependency, dep)
try:
version = _resolve_fix_version(service, dep, vulns, state)
yield ResolvedFixVersion(dep, version)
except FixResolutionImpossible as fri:
skip_reason = str(fri)
logger.debug(skip_reason)
yield SkippedFixVersion(dep, skip_reason)
def _resolve_fix_version(
service: VulnerabilityService,
dep: ResolvedDependency,
vulns: list[VulnerabilityResult],
state: AuditState,
) -> Version:
# We need to upgrade to a fix version that satisfies all vulnerability results
#
# However, whenever we upgrade a dependency, we run the risk of introducing new vulnerabilities
# so we need to run this in a loop and continue polling the vulnerability service on each
# prospective resolved fix version
current_version = dep.version
current_vulns = vulns
while current_vulns:
state.update_state(f"Resolving fix version for {dep.name}, checking {current_version}")
def get_earliest_fix_version(d: ResolvedDependency, v: VulnerabilityResult) -> Version:
for fix_version in v.fix_versions:
if fix_version > current_version:
return fix_version
raise FixResolutionImpossible(
f"failed to fix dependency {dep.name} ({dep.version}), unable to find fix version "
f"for vulnerability {v.id}"
)
# We want to retrieve a version that potentially fixes all vulnerabilities
current_version = max([get_earliest_fix_version(dep, v) for v in current_vulns])
_, current_vulns = service.query(ResolvedDependency(dep.name, current_version))
return current_version
class FixResolutionImpossible(Exception):
"""
Raised when `resolve_fix_versions` fails to find a fix version without known vulnerabilities
"""
pass
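# A self-contained sketch of the loop in `_resolve_fix_version`: a package at
# 1.0 has a vulnerability fixed in 1.1, and 1.1 surfaces a second vulnerability
# fixed in 1.2, so resolution settles on 1.2. The package name and
# vulnerability IDs are hypothetical.
if __name__ == "__main__":

    class _StubService(VulnerabilityService):
        _known = {
            Version("1.0"): [VulnerabilityResult("VULN-1", "first", [Version("1.1")], set())],
            Version("1.1"): [VulnerabilityResult("VULN-2", "second", [Version("1.2")], set())],
        }

        def query(self, spec):
            assert isinstance(spec, ResolvedDependency)
            return spec, self._known.get(spec.version, [])

    stub = _StubService()
    dep = ResolvedDependency("example-pkg", Version("1.0"))
    _, vulns = stub.query(dep)
    print(_resolve_fix_version(stub, dep, vulns, AuditState()))  # 1.2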

View File

@@ -0,0 +1,17 @@
"""
Output format interfaces and implementations for `pip-audit`.
"""
from .columns import ColumnsFormat
from .cyclonedx import CycloneDxFormat
from .interface import VulnerabilityFormat
from .json import JsonFormat
from .markdown import MarkdownFormat
__all__ = [
"ColumnsFormat",
"CycloneDxFormat",
"VulnerabilityFormat",
"JsonFormat",
"MarkdownFormat",
]

View File

@@ -0,0 +1,167 @@
"""
Functionality for formatting vulnerability results as a set of human-readable columns.
"""
from __future__ import annotations
from collections.abc import Iterable
from itertools import zip_longest
from typing import Any, cast
from packaging.version import Version
import pip_audit._fix as fix
import pip_audit._service as service
from .interface import VulnerabilityFormat
def tabulate(rows: Iterable[Iterable[Any]]) -> tuple[list[str], list[int]]:
"""Return a list of formatted rows and a list of column sizes.
For example::
>>> tabulate([['foobar', 2000], [0xdeadbeef]])
(['foobar     2000', '3735928559'], [10, 4])
"""
rows = [tuple(map(str, row)) for row in rows]
sizes = [max(map(len, col)) for col in zip_longest(*rows, fillvalue="")]
table = [" ".join(map(str.ljust, row, sizes)).rstrip() for row in rows]
return table, sizes
class ColumnsFormat(VulnerabilityFormat):
"""
An implementation of `VulnerabilityFormat` that formats vulnerability results as a set of
columns.
"""
def __init__(self, output_desc: bool, output_aliases: bool):
"""
Create a new `ColumnsFormat`.
`output_desc` is a flag to determine whether descriptions for each vulnerability should be
included in the output as they can be quite long and make the output difficult to read.
`output_aliases` is a flag to determine whether aliases (such as CVEs) for each
vulnerability should be included in the output.
"""
self.output_desc = output_desc
self.output_aliases = output_aliases
@property
def is_manifest(self) -> bool:
"""
See `VulnerabilityFormat.is_manifest`.
"""
return False
def format(
self,
result: dict[service.Dependency, list[service.VulnerabilityResult]],
fixes: list[fix.FixVersion],
) -> str:
"""
Returns a column formatted string for a given mapping of dependencies to vulnerability
results.
See `VulnerabilityFormat.format`.
"""
vuln_data: list[list[Any]] = []
header = ["Name", "Version", "ID", "Fix Versions"]
if fixes:
header.append("Applied Fix")
if self.output_aliases:
header.append("Aliases")
if self.output_desc:
header.append("Description")
vuln_data.append(header)
for dep, vulns in result.items():
if dep.is_skipped():
continue
dep = cast(service.ResolvedDependency, dep)
applied_fix = next((f for f in fixes if f.dep == dep), None)
for vuln in vulns:
vuln_data.append(self._format_vuln(dep, vuln, applied_fix))
columns_string = ""
# If it's just a header, don't bother adding it to the output
if len(vuln_data) > 1:
vuln_strings, sizes = tabulate(vuln_data)
# Create and add a separator; the outer check already guarantees rows exist.
vuln_strings.insert(1, " ".join(map(lambda x: "-" * x, sizes)))
for row in vuln_strings:
if columns_string:
columns_string += "\n"
columns_string += row
# Now display the skipped dependencies
skip_data: list[list[Any]] = []
skip_header = ["Name", "Skip Reason"]
skip_data.append(skip_header)
for dep, _ in result.items():
if dep.is_skipped():
dep = cast(service.SkippedDependency, dep)
skip_data.append(self._format_skipped_dep(dep))
# If we only have the header, that means that we haven't skipped any dependencies
# In that case, don't bother printing the header
if len(skip_data) <= 1:
return columns_string
skip_strings, sizes = tabulate(skip_data)
# Create separator for skipped dependencies columns
skip_strings.insert(1, " ".join(map(lambda x: "-" * x, sizes)))
for row in skip_strings:
if columns_string:
columns_string += "\n"
columns_string += row
return columns_string
def _format_vuln(
self,
dep: service.ResolvedDependency,
vuln: service.VulnerabilityResult,
applied_fix: fix.FixVersion | None,
) -> list[Any]:
vuln_data = [
dep.canonical_name,
dep.version,
vuln.id,
self._format_fix_versions(vuln.fix_versions),
]
if applied_fix is not None:
vuln_data.append(self._format_applied_fix(applied_fix))
if self.output_aliases:
vuln_data.append(", ".join(vuln.aliases))
if self.output_desc:
vuln_data.append(vuln.description)
return vuln_data
def _format_fix_versions(self, fix_versions: list[Version]) -> str:
return ",".join([str(version) for version in fix_versions])
def _format_skipped_dep(self, dep: service.SkippedDependency) -> list[Any]:
return [
dep.canonical_name,
dep.skip_reason,
]
def _format_applied_fix(self, applied_fix: fix.FixVersion) -> str:
if applied_fix.is_skipped():
applied_fix = cast(fix.SkippedFixVersion, applied_fix)
return (
f"Failed to fix {applied_fix.dep.canonical_name} ({applied_fix.dep.version}): "
f"{applied_fix.skip_reason}"
)
applied_fix = cast(fix.ResolvedFixVersion, applied_fix)
return (
f"Successfully upgraded {applied_fix.dep.canonical_name} ({applied_fix.dep.version} "
f"=> {applied_fix.version})"
)
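# Usage sketch: render a single finding. The package name and vulnerability
# IDs below are hypothetical; `service` and `Version` are this module's own imports.
if __name__ == "__main__":
    dep = service.ResolvedDependency("example-pkg", Version("0.5"))
    vuln = service.VulnerabilityResult(
        id="PYSEC-0000-00000",
        description="example description",
        fix_versions=[Version("1.0")],
        aliases={"CVE-0000-00000"},
    )
    fmt = ColumnsFormat(output_desc=False, output_aliases=True)
    # Prints a header row, a dashed separator, and one vulnerability row.
    print(fmt.format({dep: [vuln]}, fixes=[]))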

View File

@@ -0,0 +1,100 @@
"""
Functionality for formatting vulnerability results using the CycloneDX SBOM format.
"""
from __future__ import annotations
import enum
import logging
from typing import cast
from cyclonedx import output
from cyclonedx.model.bom import Bom
from cyclonedx.model.component import Component
from cyclonedx.model.vulnerability import Vulnerability
import pip_audit._fix as fix
import pip_audit._service as service
from .interface import VulnerabilityFormat
logger = logging.getLogger(__name__)
def _pip_audit_result_to_bom(
result: dict[service.Dependency, list[service.VulnerabilityResult]],
) -> Bom:
vulnerabilities = []
components = []
for dep, vulns in result.items():
# TODO(alex): Is there anything interesting we can do with skipped dependencies in
# the CycloneDX format?
if dep.is_skipped():
continue
dep = cast(service.ResolvedDependency, dep)
c = Component(name=dep.name, version=str(dep.version))
for vuln in vulns:
vulnerabilities.append(
Vulnerability(id=vuln.id, description=vuln.description, recommendation="Upgrade")
)
components.append(c)
return Bom(components=components, vulnerabilities=vulnerabilities)
class CycloneDxFormat(VulnerabilityFormat):
"""
An implementation of `VulnerabilityFormat` that formats vulnerability results using CycloneDX.
The container format used by CycloneDX can be additionally configured.
"""
@enum.unique
class InnerFormat(enum.Enum):
"""
Valid container formats for CycloneDX.
"""
Json = output.OutputFormat.JSON
Xml = output.OutputFormat.XML
def __init__(self, inner_format: CycloneDxFormat.InnerFormat):
"""
Create a new `CycloneDxFormat`.
`inner_format` determines the container format used by CycloneDX.
"""
self._inner_format = inner_format
@property
def is_manifest(self) -> bool:
"""
See `VulnerabilityFormat.is_manifest`.
"""
return True
def format(
self,
result: dict[service.Dependency, list[service.VulnerabilityResult]],
fixes: list[fix.FixVersion],
) -> str:
"""
Returns a CycloneDX formatted string for a given mapping of dependencies to vulnerability
results.
See `VulnerabilityFormat.format`.
"""
if fixes:
logger.warning("--fix output is unsupported by CycloneDX formats")
bom = _pip_audit_result_to_bom(result)
formatter = output.make_outputter(
bom=bom,
output_format=self._inner_format.value,
schema_version=output.SchemaVersion.V1_4,
)
return formatter.output_as_string()
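# Usage sketch: emit a JSON-container CycloneDX SBOM for a result mapping.
# The dependency below is hypothetical; `Version` is imported locally since
# this module doesn't otherwise need it.
if __name__ == "__main__":
    from packaging.version import Version

    dep = service.ResolvedDependency("example-pkg", Version("0.5"))
    fmt = CycloneDxFormat(inner_format=CycloneDxFormat.InnerFormat.Json)
    print(fmt.format({dep: []}, fixes=[]))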

View File

@@ -0,0 +1,40 @@
"""
Interfaces for formatting vulnerability results into a string representation.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import pip_audit._fix as fix
import pip_audit._service as service
class VulnerabilityFormat(ABC):
"""
Represents an abstract string representation for vulnerability results.
"""
@property
@abstractmethod
def is_manifest(self) -> bool: # pragma: no cover
"""
Is this format a "manifest" format, i.e. one that prints a summary
of all results?
Manifest formats are rendered unconditionally, even
if the audit results contain nothing out of the ordinary
(no vulnerabilities, skips, or fixes).
"""
raise NotImplementedError
@abstractmethod
def format(
self,
result: dict[service.Dependency, list[service.VulnerabilityResult]],
fixes: list[fix.FixVersion],
) -> str: # pragma: no cover
"""
Convert a mapping of dependencies to vulnerabilities into a string.
"""
raise NotImplementedError
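# A minimal concrete implementation, to make the contract above tangible: a
# "manifest" format that always renders a one-line summary of the results.
# This class is purely illustrative and not part of `pip-audit`.
class _CountFormat(VulnerabilityFormat):
    @property
    def is_manifest(self) -> bool:
        return True

    def format(
        self,
        result: dict[service.Dependency, list[service.VulnerabilityResult]],
        fixes: list[fix.FixVersion],
    ) -> str:
        n_vulns = sum(len(vulns) for vulns in result.values())
        return f"{n_vulns} known vulnerabilities across {len(result)} dependencies"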

View File

@@ -0,0 +1,105 @@
"""
Functionality for formatting vulnerability results as an array of JSON objects.
"""
from __future__ import annotations
import json
from typing import Any, cast
import pip_audit._fix as fix
import pip_audit._service as service
from .interface import VulnerabilityFormat
class JsonFormat(VulnerabilityFormat):
"""
An implementation of `VulnerabilityFormat` that formats vulnerability results as an array of
JSON objects.
"""
def __init__(self, output_desc: bool, output_aliases: bool):
"""
Create a new `JsonFormat`.
`output_desc` is a flag to determine whether descriptions for each vulnerability should be
included in the output as they can be quite long and make the output difficult to read.
`output_aliases` is a flag to determine whether aliases (such as CVEs) for each
vulnerability should be included in the output.
"""
self.output_desc = output_desc
self.output_aliases = output_aliases
@property
def is_manifest(self) -> bool:
"""
See `VulnerabilityFormat.is_manifest`.
"""
return True
def format(
self,
result: dict[service.Dependency, list[service.VulnerabilityResult]],
fixes: list[fix.FixVersion],
) -> str:
"""
Returns a JSON formatted string for a given mapping of dependencies to vulnerability
results.
See `VulnerabilityFormat.format`.
"""
output_json = {}
dep_json = []
for dep, vulns in result.items():
dep_json.append(self._format_dep(dep, vulns))
output_json["dependencies"] = dep_json
fix_json = []
for f in fixes:
fix_json.append(self._format_fix(f))
output_json["fixes"] = fix_json
return json.dumps(output_json)
def _format_dep(
self, dep: service.Dependency, vulns: list[service.VulnerabilityResult]
) -> dict[str, Any]:
if dep.is_skipped():
dep = cast(service.SkippedDependency, dep)
return {
"name": dep.canonical_name,
"skip_reason": dep.skip_reason,
}
dep = cast(service.ResolvedDependency, dep)
return {
"name": dep.canonical_name,
"version": str(dep.version),
"vulns": [self._format_vuln(vuln) for vuln in vulns],
}
def _format_vuln(self, vuln: service.VulnerabilityResult) -> dict[str, Any]:
vuln_json = {
"id": vuln.id,
"fix_versions": [str(version) for version in vuln.fix_versions],
}
if self.output_aliases:
vuln_json["aliases"] = list(vuln.aliases)
if self.output_desc:
vuln_json["description"] = vuln.description
return vuln_json
def _format_fix(self, fix_version: fix.FixVersion) -> dict[str, Any]:
if fix_version.is_skipped():
fix_version = cast(fix.SkippedFixVersion, fix_version)
return {
"name": fix_version.dep.canonical_name,
"version": str(fix_version.dep.version),
"skip_reason": fix_version.skip_reason,
}
fix_version = cast(fix.ResolvedFixVersion, fix_version)
return {
"name": fix_version.dep.canonical_name,
"old_version": str(fix_version.dep.version),
"new_version": str(fix_version.version),
}
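# Usage sketch showing the emitted document shape; the package name and
# vulnerability ID are hypothetical.
if __name__ == "__main__":
    from packaging.version import Version

    dep = service.ResolvedDependency("example-pkg", Version("0.5"))
    vuln = service.VulnerabilityResult(
        id="PYSEC-0000-00000",
        description="example description",
        fix_versions=[Version("1.0")],
        aliases=set(),
    )
    print(JsonFormat(output_desc=False, output_aliases=False).format({dep: [vuln]}, fixes=[]))
    # Emits (one line, wrapped here for readability):
    # {"dependencies": [{"name": "example-pkg", "version": "0.5",
    #   "vulns": [{"id": "PYSEC-0000-00000", "fix_versions": ["1.0"]}]}], "fixes": []}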

View File

@@ -0,0 +1,156 @@
"""
Functionality for formatting vulnerability results as a Markdown table.
"""
from __future__ import annotations
from textwrap import dedent
from typing import cast
from packaging.version import Version
import pip_audit._fix as fix
import pip_audit._service as service
from .interface import VulnerabilityFormat
class MarkdownFormat(VulnerabilityFormat):
"""
An implementation of `VulnerabilityFormat` that formats vulnerability results as a set of
Markdown tables.
"""
def __init__(self, output_desc: bool, output_aliases: bool) -> None:
"""
Create a new `MarkdownFormat`.
`output_desc` is a flag to determine whether descriptions for each vulnerability should be
included in the output as they can be quite long and make the output difficult to read.
`output_aliases` is a flag to determine whether aliases (such as CVEs) for each
vulnerability should be included in the output.
"""
self.output_desc = output_desc
self.output_aliases = output_aliases
@property
def is_manifest(self) -> bool:
"""
See `VulnerabilityFormat.is_manifest`.
"""
return False
def format(
self,
result: dict[service.Dependency, list[service.VulnerabilityResult]],
fixes: list[fix.FixVersion],
) -> str:
"""
Returns a Markdown formatted string representing a set of vulnerability results and applied
fixes.
"""
output = self._format_vuln_results(result, fixes)
skipped_deps_output = self._format_skipped_deps(result)
if skipped_deps_output:
# If we wrote the results table already, we need to add some line breaks to ensure that
# the skipped dependency table renders correctly.
if output:
output += "\n"
output += skipped_deps_output
return output
def _format_vuln_results(
self,
result: dict[service.Dependency, list[service.VulnerabilityResult]],
fixes: list[fix.FixVersion],
) -> str:
header = "Name | Version | ID | Fix Versions"
border = "--- | --- | --- | ---"
if fixes:
header += " | Applied Fix"
border += " | ---"
if self.output_aliases:
header += " | Aliases"
border += " | ---"
if self.output_desc:
header += " | Description"
border += " | ---"
vuln_rows: list[str] = []
for dep, vulns in result.items():
if dep.is_skipped():
continue
dep = cast(service.ResolvedDependency, dep)
applied_fix = next((f for f in fixes if f.dep == dep), None)
for vuln in vulns:
vuln_rows.append(self._format_vuln(dep, vuln, applied_fix))
if not vuln_rows:
return ""
return dedent(
f"""
{header}
{border}
"""
) + "\n".join(vuln_rows)
def _format_vuln(
self,
dep: service.ResolvedDependency,
vuln: service.VulnerabilityResult,
applied_fix: fix.FixVersion | None,
) -> str:
vuln_text = (
f"{dep.canonical_name} | {dep.version} | {vuln.id} | "
f"{self._format_fix_versions(vuln.fix_versions)}"
)
if applied_fix is not None:
vuln_text += f" | {self._format_applied_fix(applied_fix)}"
if self.output_aliases:
vuln_text += f" | {', '.join(vuln.aliases)}"
if self.output_desc:
vuln_text += f" | {vuln.description}"
return vuln_text
def _format_fix_versions(self, fix_versions: list[Version]) -> str:
return ",".join([str(version) for version in fix_versions])
def _format_applied_fix(self, applied_fix: fix.FixVersion) -> str:
if applied_fix.is_skipped():
applied_fix = cast(fix.SkippedFixVersion, applied_fix)
return (
f"Failed to fix {applied_fix.dep.canonical_name} ({applied_fix.dep.version}): "
f"{applied_fix.skip_reason}"
)
applied_fix = cast(fix.ResolvedFixVersion, applied_fix)
return (
f"Successfully upgraded {applied_fix.dep.canonical_name} ({applied_fix.dep.version} "
f"=> {applied_fix.version})"
)
def _format_skipped_deps(
self, result: dict[service.Dependency, list[service.VulnerabilityResult]]
) -> str:
header = "Name | Skip Reason"
border = "--- | ---"
skipped_dep_rows: list[str] = []
for dep, _ in result.items():
if dep.is_skipped():
dep = cast(service.SkippedDependency, dep)
skipped_dep_rows.append(self._format_skipped_dep(dep))
if not skipped_dep_rows:
return ""
return dedent(
f"""
{header}
{border}
"""
) + "\n".join(skipped_dep_rows)
def _format_skipped_dep(self, dep: service.SkippedDependency) -> str:
return f"{dep.name} | {dep.skip_reason}"

View File

@@ -0,0 +1,27 @@
"""
Vulnerability service interfaces and implementations for `pip-audit`.
"""
from .interface import (
ConnectionError,
Dependency,
ResolvedDependency,
ServiceError,
SkippedDependency,
VulnerabilityResult,
VulnerabilityService,
)
from .osv import OsvService
from .pypi import PyPIService
__all__ = [
"ConnectionError",
"Dependency",
"ResolvedDependency",
"ServiceError",
"SkippedDependency",
"VulnerabilityResult",
"VulnerabilityService",
"OsvService",
"PyPIService",
]

View File

@@ -0,0 +1,190 @@
"""
Interfaces for interacting with vulnerability services, i.e. sources
of vulnerability information for fully resolved Python packages.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass, replace
from datetime import datetime
from typing import Any, NewType
from packaging.utils import canonicalize_name
from packaging.version import Version
VulnerabilityID = NewType("VulnerabilityID", str)
@dataclass(frozen=True)
class Dependency:
"""
Represents an abstract Python package.
This class cannot be constructed directly.
"""
name: str
"""
The package's **uncanonicalized** name.
Use the `canonical_name` property when a canonicalized form is necessary.
"""
def __init__(self, *_args: Any, **_kwargs: Any) -> None:
"""
A stub constructor that always fails.
"""
raise NotImplementedError
# TODO(ww): Use functools.cached_property when supported Python is 3.8+.
@property
def canonical_name(self) -> str:
"""
The `Dependency`'s PEP-503 canonicalized name.
"""
return canonicalize_name(self.name)
def is_skipped(self) -> bool:
"""
Check whether the `Dependency` was skipped by the audit.
"""
return self.__class__ is SkippedDependency
@dataclass(frozen=True)
class ResolvedDependency(Dependency):
"""
Represents a fully resolved Python package.
"""
version: Version
@dataclass(frozen=True)
class SkippedDependency(Dependency):
"""
Represents a Python package that could not be audited and was therefore skipped.
"""
skip_reason: str
@dataclass(frozen=True)
class VulnerabilityResult:
"""
Represents a "result" from a vulnerability service, indicating a vulnerability
in some Python package.
"""
id: VulnerabilityID
"""
A service-provided identifier for the vulnerability.
"""
description: str
"""
A human-readable description of the vulnerability.
"""
fix_versions: list[Version]
"""
A list of versions that can be upgraded to that resolve the vulnerability.
"""
aliases: set[str]
"""
A set of aliases (alternative identifiers) for this result.
"""
published: datetime | None = None
"""
When the vulnerability was first published.
"""
def alias_of(self, other: VulnerabilityResult) -> bool:
"""
Returns whether this result is an "alias" of another result.
Two results are said to be aliases if their respective sets of
`{id, *aliases}` intersect at all. A result is therefore its own alias.
"""
return bool((self.aliases | {self.id}).intersection(other.aliases | {other.id}))
def merge_aliases(self, other: VulnerabilityResult) -> VulnerabilityResult:
"""
Merge `other`'s aliases into this result, returning a new result.
"""
# Our own ID should never occur in the alias set. Note that `-` binds more
# tightly than `|`, so our ID is stripped from `other`'s aliases before the union.
aliases = self.aliases | (other.aliases - {self.id})
return replace(self, aliases=aliases)
def has_any_id(self, ids: set[str]) -> bool:
"""
Returns whether `ids` intersects with this result's `{id} | aliases` set.
"""
return bool(ids & (self.aliases | {self.id}))
class VulnerabilityService(ABC):
"""
Represents an abstract provider of Python package vulnerability information.
"""
@abstractmethod
def query(
self, spec: Dependency
) -> tuple[Dependency, list[VulnerabilityResult]]: # pragma: no cover
"""
Query the `VulnerabilityService` for information about the given `Dependency`,
returning a list of `VulnerabilityResult`.
"""
raise NotImplementedError
def query_all(
self, specs: Iterator[Dependency]
) -> Iterator[tuple[Dependency, list[VulnerabilityResult]]]:
"""
Query the vulnerability service for information on multiple dependencies.
`VulnerabilityService` implementations can override this implementation with
a more optimized one, if they support batched or bulk requests.
"""
for spec in specs:
yield self.query(spec)
@staticmethod
def _parse_rfc3339(dt: str | None) -> datetime | None:
if dt is None:
return None
# NOTE: OSV's schema says timestamps are RFC3339 but strptime
# has no way to indicate an optional field (like `%f`), so
# we have to try-and-retry with the two different expected formats.
# See: https://github.com/google/osv.dev/issues/857
try:
return datetime.strptime(dt, "%Y-%m-%dT%H:%M:%S.%fZ")
except ValueError:
return datetime.strptime(dt, "%Y-%m-%dT%H:%M:%SZ")
class ServiceError(Exception):
"""
Raised when a `VulnerabilityService` fails, for any reason.
Concrete implementations of `VulnerabilityService` are expected to subclass
this exception to provide more context.
"""
pass
class ConnectionError(ServiceError):
"""
A specialization of `ServiceError` specifically for cases where the
vulnerability service is unreachable or offline.
"""
pass
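# A short demonstration of the alias semantics above (all IDs hypothetical):
# two results that share a CVE alias are aliases of one another, and merging
# unions their alias sets while keeping our own ID out. Note that `other`'s
# primary ID is not folded in, only its aliases.
if __name__ == "__main__":
    a = VulnerabilityResult("PYSEC-0000-00000", "...", [], {"CVE-0000-00000"})
    b = VulnerabilityResult("GHSA-xxxx-xxxx-xxxx", "...", [], {"CVE-0000-00000"})
    assert a.alias_of(b) and b.alias_of(a)
    assert a.merge_aliases(b).aliases == {"CVE-0000-00000"}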

View File

@@ -0,0 +1,155 @@
"""
Functionality for using the [OSV](https://osv.dev/) API as a `VulnerabilityService`.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any, cast
import requests
from packaging.version import Version
from pip_audit._cache import caching_session
from pip_audit._service.interface import (
ConnectionError,
Dependency,
ResolvedDependency,
ServiceError,
VulnerabilityResult,
VulnerabilityService,
)
logger = logging.getLogger(__name__)
class OsvService(VulnerabilityService):
"""
An implementation of `VulnerabilityService` that uses OSV to provide Python
package vulnerability information.
"""
def __init__(self, cache_dir: Path | None = None, timeout: int | None = None):
"""
Create a new `OsvService`.
`cache_dir` is an optional cache directory to use, for caching and reusing OSV API
requests. If `None`, `pip-audit` will use its own internal caching directory.
`timeout` is an optional argument to control how many seconds the component should wait for
responses to network requests.
"""
self.session = caching_session(cache_dir, use_pip=False)
self.timeout = timeout
def query(self, spec: Dependency) -> tuple[Dependency, list[VulnerabilityResult]]:
"""
Queries OSV for the given `Dependency` specification.
See `VulnerabilityService.query`.
"""
if spec.is_skipped():
return spec, []
spec = cast(ResolvedDependency, spec)
url = "https://api.osv.dev/v1/query"
query = {
"package": {"name": spec.canonical_name, "ecosystem": "PyPI"},
"version": str(spec.version),
}
try:
response: requests.Response = self.session.post(
url=url,
data=json.dumps(query),
timeout=self.timeout,
)
response.raise_for_status()
except requests.ConnectTimeout:
raise ConnectionError("Could not connect to OSV's vulnerability feed")
except requests.HTTPError as http_error:
raise ServiceError from http_error
# If the response is empty, that means that the package/version pair doesn't have any
# associated vulnerabilities
#
# In that case, return an empty list
results: list[VulnerabilityResult] = []
response_json = response.json()
if not response_json:
return spec, results
vuln: dict[str, Any]
for vuln in response_json["vulns"]:
# Sanity check: only the v1 schema is specified at the moment,
# and the code below probably won't work with future incompatible
# schemas without additional changes.
# The absence of a schema is treated as 1.0.0, per the OSV spec.
schema_version = Version(vuln.get("schema_version", "1.0.0"))
if schema_version.major != 1:
logger.warning(f"Unsupported OSV schema version: {schema_version}")
continue
id = vuln["id"]
# If the vulnerability has been withdrawn, we skip it entirely.
withdrawn_at = vuln.get("withdrawn")
if withdrawn_at is not None:
logger.debug(f"OSV vuln entry '{id}' marked as withdrawn at {withdrawn_at}")
continue
# The summary is intended to be shorter, so we prefer it over
# details, if present. However, neither is required.
description = vuln.get("summary")
if description is None:
description = vuln.get("details")
if description is None:
description = "N/A"
# The "summary" field should be a single line, but "details" might
# be multiple (Markdown-formatted) lines. So, we normalize our
# description into a single line (and potentially break the Markdown
# formatting in the process).
description = description.replace("\n", " ")
# OSV doesn't mandate this field either. There's very little we
# can do without it, so we skip any results that are missing it.
affecteds = vuln.get("affected")
if affecteds is None:
logger.warning(f"OSV vuln entry '{id}' is missing 'affected' list")
continue
fix_versions: list[Version] = []
for affected in affecteds:
pkg = affected["package"]
# We only care about PyPI versions
if pkg["name"] == spec.canonical_name and pkg["ecosystem"] == "PyPI":
for ranges in affected["ranges"]:
if ranges["type"] == "ECOSYSTEM":
# Filter out non-fix versions
fix_version_strs = [
version["fixed"]
for version in ranges["events"]
if "fixed" in version
]
# Convert them to version objects
fix_versions = [
Version(version_str) for version_str in fix_version_strs
]
break
# The ranges aren't guaranteed to come in chronological order
fix_versions.sort()
results.append(
VulnerabilityResult(
id=id,
description=description,
fix_versions=fix_versions,
aliases=set(vuln.get("aliases", [])),
published=self._parse_rfc3339(vuln.get("published")),
)
)
return spec, results
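# The query above, reproduced as a standalone request against the live OSV API.
# The package/version pair is an arbitrary example.
if __name__ == "__main__":
    query = {
        "package": {"name": "jinja2", "ecosystem": "PyPI"},
        "version": "2.4.1",
    }
    response = requests.post("https://api.osv.dev/v1/query", data=json.dumps(query), timeout=10)
    response.raise_for_status()
    for entry in response.json().get("vulns", []):
        print(entry["id"])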

View File

@@ -0,0 +1,135 @@
"""
Functionality for using the [PyPI](https://warehouse.pypa.io/api-reference/json.html)
API as a `VulnerabilityService`.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import cast
import requests
from packaging.version import InvalidVersion, Version
from pip_audit._cache import caching_session
from pip_audit._service.interface import (
ConnectionError,
Dependency,
ResolvedDependency,
ServiceError,
SkippedDependency,
VulnerabilityResult,
VulnerabilityService,
)
logger = logging.getLogger(__name__)
class PyPIService(VulnerabilityService):
"""
An implementation of `VulnerabilityService` that uses PyPI to provide Python
package vulnerability information.
"""
def __init__(self, cache_dir: Path | None = None, timeout: int | None = None) -> None:
"""
Create a new `PyPIService`.
`cache_dir` is an optional cache directory to use, for caching and reusing PyPI API
requests. If `None`, `pip-audit` will attempt to use `pip`'s cache directory before falling
back on its own default cache directory.
`timeout` is an optional argument to control how many seconds the component should wait for
responses to network requests.
"""
self.session = caching_session(cache_dir)
self.timeout = timeout
def query(self, spec: Dependency) -> tuple[Dependency, list[VulnerabilityResult]]:
"""
Queries PyPI for the given `Dependency` specification.
See `VulnerabilityService.query`.
"""
if spec.is_skipped():
return spec, []
spec = cast(ResolvedDependency, spec)
url = f"https://pypi.org/pypi/{spec.canonical_name}/{str(spec.version)}/json"
try:
response: requests.Response = self.session.get(url=url, timeout=self.timeout)
response.raise_for_status()
except requests.TooManyRedirects:
# This should never happen with a healthy PyPI instance, but might
# happen during an outage or network event.
# Ref 2022-06-10: https://status.python.org/incidents/lgpr13fy71bk
raise ConnectionError("PyPI is not redirecting properly")
except requests.ConnectTimeout:
# Apart from a normal network outage, this can happen for two main
# reasons:
# 1. PyPI's APIs are offline
# 2. The user is behind a firewall or corporate network that blocks
# PyPI (and they're probably using custom indices)
raise ConnectionError("Could not connect to PyPI's vulnerability feed")
except requests.HTTPError as http_error:
if response.status_code == 404:
skip_reason = (
"Dependency not found on PyPI and could not be audited: "
f"{spec.canonical_name} ({spec.version})"
)
logger.debug(skip_reason)
return SkippedDependency(name=spec.name, skip_reason=skip_reason), []
raise ServiceError from http_error
response_json = response.json()
results: list[VulnerabilityResult] = []
vulns = response_json.get("vulnerabilities")
# No `vulnerabilities` key means that there are no vulnerabilities for any version
if vulns is None:
return spec, results
for v in vulns:
id = v["id"]
# If the vulnerability has been withdrawn, we skip it entirely.
withdrawn_at = v.get("withdrawn")
if withdrawn_at is not None:
logger.debug(f"PyPI vuln entry '{id}' marked as withdrawn at {withdrawn_at}")
continue
# Put together the fix versions list
try:
fix_versions = [Version(fixed_in) for fixed_in in v["fixed_in"]]
except InvalidVersion as iv:
raise ServiceError(f"Received malformed version from PyPI: {v['fixed_in']}") from iv
# The ranges aren't guaranteed to come in chronological order
fix_versions.sort()
description = v.get("summary")
if description is None:
description = v.get("details")
if description is None:
description = "N/A"
# The "summary" field should be a single line, but "details" might
# be multiple (Markdown-formatted) lines. So, we normalize our
# description into a single line (and potentially break the Markdown
# formatting in the process).
description = description.replace("\n", " ")
results.append(
VulnerabilityResult(
id=id,
description=description,
fix_versions=fix_versions,
aliases=set(v["aliases"]),
published=self._parse_rfc3339(v.get("published")),
)
)
return spec, results
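# The same lookup as a standalone request: PyPI's JSON API carries a
# `vulnerabilities` key per release. The package/version pair is an arbitrary example.
if __name__ == "__main__":
    response = requests.get("https://pypi.org/pypi/flask/0.5/json", timeout=10)
    response.raise_for_status()
    for entry in response.json().get("vulnerabilities") or []:
        print(entry["id"], entry["fixed_in"])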

View File

@@ -0,0 +1,274 @@
"""
Interfaces for propagating feedback from the API to provide responsive progress indicators, as
well as a progress spinner implementation for use with CLI applications.
"""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from logging.handlers import MemoryHandler
from typing import Any
from rich.align import StyleType
from rich.console import Console, Group, RenderableType
from rich.live import Live
from rich.panel import Panel
from rich.status import Spinner
class AuditState:
"""
An object that handles abstract "updates" to `pip-audit`'s state.
Non-UI consumers of `pip-audit` (via `pip_audit`) should have no need for
this class, and can leave it as a default construction in whatever signatures
it appears in. Its primary use is internal and UI-specific: it exists solely
to give the CLI enough state for a responsive progress indicator during
user requests.
"""
def __init__(self, *, members: Sequence[_StateActor] = []):
"""
Create a new `AuditState` with the given member list.
"""
self._members = members
def update_state(self, message: str, logs: str | None = None) -> None:
"""
Called whenever `pip_audit`'s internal state changes in a way that's meaningful to
expose to a user.
`message` is the message to present to the user.
"""
for member in self._members:
member.update_state(message, logs)
def initialize(self) -> None:
"""
Called when `pip-audit`'s state is initializing.
"""
for member in self._members:
member.initialize()
def finalize(self) -> None:
"""
Called when `pip_audit`'s state is "done" changing.
"""
for member in self._members:
member.finalize()
def __enter__(self) -> AuditState: # pragma: no cover
"""
Create an instance of the `pip-audit` state for usage within a `with` statement.
"""
self.initialize()
return self
def __exit__(
self, _exc_type: Any, _exc_value: Any, _exc_traceback: Any
) -> None: # pragma: no cover
"""
Helper to ensure `finalize` gets called when the `pip-audit` state falls out of scope of a
`with` statement.
"""
self.finalize()
class _StateActor(ABC):
@abstractmethod
def update_state(self, message: str, logs: str | None = None) -> None:
raise NotImplementedError # pragma: no cover
@abstractmethod
def initialize(self) -> None:
"""
Called when `pip-audit`'s state is initializing. Implementors should
override this to do nothing if their state management requires no
initialization step.
"""
raise NotImplementedError # pragma: no cover
@abstractmethod
def finalize(self) -> None:
"""
Called when the overlaying `AuditState` is "done," i.e. `pip-audit`'s
state is done changing. Implementors should override this to do nothing
if their state management requires no finalization step.
"""
raise NotImplementedError # pragma: no cover
class StatusLog: # pragma: no cover
"""
Displays a status indicator with an optional log panel to display logs
for external processes.
This code is based off of Rich's `Status` component:
https://github.com/Textualize/rich/blob/master/rich/status.py
"""
# NOTE(alex): We limit the panel to 10 characters high and display the last 10 log lines.
# However, the panel won't display all 10 of those lines if some of the lines are long enough
# to wrap in the panel.
LOG_PANEL_HEIGHT = 10
def __init__(
self,
status: str,
*,
console: Console | None = None,
spinner: str = "dots",
spinner_style: StyleType = "status.spinner",
speed: float = 1.0,
refresh_per_second: float = 12.5,
):
"""
Construct a new `StatusLog`.
`status` is the status message to display next to the spinner.
`console` is the Rich console to display the log status in.
`spinner` is the name of the spinner animation (see python -m rich.spinner). Defaults to `dots`.
`spinner_style` is the style of the spinner. Defaults to `status.spinner`.
`speed` is the speed factor for the spinner animation. Defaults to 1.0.
`refresh_per_second` is the number of refreshes per second. Defaults to 12.5.
"""
self._spinner = Spinner(spinner, text=status, style=spinner_style, speed=speed)
self._log_panel = Panel("", height=self.LOG_PANEL_HEIGHT)
self._live = Live(
self.renderable,
console=console,
refresh_per_second=refresh_per_second,
transient=True,
)
@property
def renderable(self) -> RenderableType:
"""
Create a Rich renderable type for the log panel.
If the log panel contains text, we should create a group and place the
log panel underneath the spinner.
"""
if self._log_panel.renderable:
return Group(self._spinner, self._log_panel)
return self._spinner
def update(
self,
status: str,
logs: str | None,
) -> None:
"""
Update status and logs.
"""
if logs is None:
logs = ""
else:
# Limit the logging output to the 10 most recent lines.
logs = "\n".join(logs.splitlines()[-self.LOG_PANEL_HEIGHT :])
self._spinner.update(text=status)
self._log_panel.renderable = logs
self._live.update(self.renderable, refresh=True)
def start(self) -> None:
"""
Start the status animation.
"""
self._live.start()
def stop(self) -> None:
"""
Stop the spinner animation.
"""
self._live.stop()
def __rich__(self) -> RenderableType:
"""
Convert to a Rich renderable type.
"""
return self.renderable
class AuditSpinner(_StateActor): # pragma: no cover
"""
A progress spinner for `pip-audit`, using `rich.status`'s spinner support
under the hood.
"""
def __init__(self, message: str = "") -> None:
"""
Initialize the `AuditSpinner`.
"""
self._console = Console()
# NOTE: audits can be quite fast, so we need a pretty high refresh rate here.
self._spinner = StatusLog(
message, console=self._console, spinner="line", refresh_per_second=30
)
# Keep the target set to `None` to ensure that the logs don't get written until the spinner
# has finished writing output, regardless of the capacity argument
self.log_handler = MemoryHandler(
0, flushLevel=logging.ERROR, target=None, flushOnClose=False
)
self.prev_handlers: list[logging.Handler] = []
def update_state(self, message: str, logs: str | None = None) -> None:
"""
Update the spinner's state.
"""
self._spinner.update(message, logs)
def initialize(self) -> None:
"""
Redirect logging to an in-memory log handler so that it doesn't get mixed in with the
spinner output.
"""
# Remove all existing log handlers
#
# We're recording them here since we'll want to restore them once the spinner falls out of
# scope
root_logger = logging.root
for handler in root_logger.handlers:
self.prev_handlers.append(handler)
for handler in self.prev_handlers:
root_logger.removeHandler(handler)
# Redirect logging to our in-memory handler that will buffer the log lines
root_logger.addHandler(self.log_handler)
self._spinner.start()
def finalize(self) -> None:
"""
Clean up the spinner output so it doesn't get combined with subsequent `stderr` output, and
flush any logs that were recorded while the spinner was active.
"""
self._spinner.stop()
# Now that the spinner is complete, flush the logs
root_logger = logging.root
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
self.log_handler.setTarget(stream_handler)
self.log_handler.flush()
# Restore the original log handlers
root_logger.removeHandler(self.log_handler)
for handler in self.prev_handlers:
root_logger.addHandler(handler)
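# A toy state actor, to show the `AuditState` plumbing end to end: anything
# implementing `_StateActor` can observe state updates. This actor is purely
# illustrative; real CLI runs use `AuditSpinner`.
if __name__ == "__main__":

    class _EchoActor(_StateActor):
        def update_state(self, message: str, logs: str | None = None) -> None:
            print(f"[state] {message}")

        def initialize(self) -> None:
            print("[state] started")

        def finalize(self) -> None:
            print("[state] done")

    with AuditState(members=[_EchoActor()]) as state:
        state.update_state("doing some work")
    # [state] started
    # [state] doing some work
    # [state] done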

View File

@@ -0,0 +1,67 @@
"""
A thin `subprocess` wrapper for making long-running subprocesses more
responsive from the `pip-audit` CLI.
"""
import os.path
import subprocess
from collections.abc import Sequence
from subprocess import Popen
from ._state import AuditState
class CalledProcessError(Exception):
"""
Raised if the underlying subprocess created by `run` exits with a nonzero code.
"""
def __init__(self, msg: str, *, stderr: str) -> None:
"""
Create a new `CalledProcessError`.
"""
super().__init__(msg)
self.stderr = stderr
def run(args: Sequence[str], *, log_stdout: bool = False, state: AuditState = AuditState()) -> str:
"""
Execute the given arguments.
Uses `state` to provide feedback on the subprocess's status.
Raises a `CalledProcessError` if the subprocess fails. Otherwise, returns
the process's `stdout` stream as a string.
"""
# NOTE(ww): We frequently run commands inside of ephemeral virtual environments,
# which have long absolute paths on some platforms. These make for confusing
# state updates, so we trim the first argument down to its basename.
pretty_args = " ".join([os.path.basename(args[0]), *args[1:]])
terminated = False
stdout = b""
stderr = b""
# Run the process with unbuffered I/O, to make the poll-and-read loop below
# more responsive.
with Popen(args, bufsize=0, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as process:
# NOTE: We use `poll()` to control this loop instead of the `read()` call
# to prevent deadlocks. Similarly, `read(size)` will return an empty bytes
# once `stdout` hits EOF, so we don't have to worry about that blocking.
while not terminated:
terminated = process.poll() is not None
stdout += process.stdout.read() # type: ignore
stderr += process.stderr.read() # type: ignore
state.update_state(
f"Running {pretty_args}",
stdout.decode(errors="replace") if log_stdout else None,
)
if process.returncode != 0:
raise CalledProcessError(
f"{pretty_args} exited with {process.returncode}",
stderr=stderr.decode(errors="replace"),
)
return stdout.decode("utf-8", errors="replace")
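# Usage sketch: capture a subprocess's stdout, with stderr surfaced on failure.
if __name__ == "__main__":
    import sys

    try:
        print(run([sys.executable, "--version"]), end="")
    except CalledProcessError as cpe:
        print(f"subprocess failed: {cpe}", cpe.stderr)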

View File

@@ -0,0 +1,26 @@
"""
Utility functions for `pip-audit`.
"""
import sys
from typing import NoReturn # pragma: no cover
from packaging.version import Version
def assert_never(x: NoReturn) -> NoReturn: # pragma: no cover
"""
A hint to the typechecker that a branch can never occur.
"""
assert False, f"unhandled type: {type(x).__name__}"
def python_version() -> Version:
"""
Return a PEP-440-style version for the current Python interpreter.
This is more rigorous than `platform.python_version`, which can include
non-PEP-440-compatible data.
"""
info = sys.version_info
return Version(f"{info.major}.{info.minor}.{info.micro}")

View File

@@ -0,0 +1,215 @@
"""
Create virtual environments with a custom set of packages and inspect their dependencies.
"""
from __future__ import annotations
import json
import logging
import venv
from collections.abc import Iterator
from os import PathLike
from tempfile import NamedTemporaryFile, TemporaryDirectory, gettempdir
from types import SimpleNamespace
from packaging.version import Version
from ._state import AuditState
from ._subprocess import CalledProcessError, run
logger = logging.getLogger(__name__)
class VirtualEnv(venv.EnvBuilder):
"""
A wrapper around `EnvBuilder` that allows a custom `pip install` command to be executed, and its
resulting dependencies inspected.
The `pip-audit` API uses this functionality internally to deduce what the dependencies are for a
given requirements file since this can't be determined statically.
The `create` method MUST be called before inspecting the `installed_packages` property;
otherwise, a `VirtualEnvError` will be raised.
The expected usage is:
```
# Create a virtual environment and install the `pip-api` package.
ve = VirtualEnv(["pip-api"])
ve.create(".venv/")
for (name, version) in ve.installed_packages:
print(f"Installed package {name} ({version})")
```
"""
def __init__(
self,
install_args: list[str],
index_url: str | None = None,
extra_index_urls: list[str] = [],
state: AuditState = AuditState(),
):
"""
Create a new `VirtualEnv`.
`install_args` is the list of arguments that make up the custom install command. For
example, if you wanted to execute `pip install -e /tmp/my_pkg`, you would create the
`VirtualEnv` like so:
```
ve = VirtualEnv(["-e", "/tmp/my_pkg"])
```
`index_url` is the base URL of the package index.
`extra_index_urls` are the extra URLs of package indexes.
`state` is an `AuditState` to use for state callbacks.
"""
super().__init__(with_pip=True)
self._install_args = install_args
self._index_url = index_url
self._extra_index_urls = extra_index_urls
self._packages: list[tuple[str, Version]] | None = None
self._state = state
def create(self, env_dir: str | bytes | PathLike[str] | PathLike[bytes]) -> None:
"""
Creates the virtual environment.
"""
try:
return super().create(env_dir)
except PermissionError:
# `venv` uses a subprocess internally to bootstrap pip, but
# some Linux distributions choose to mark the system temporary
# directory as `noexec`. Apart from having only nominal security
# benefits, this completely breaks our ability to execute from
# within the temporary virtualenv.
#
# We may be able to hack around this in the future, but doing so
# isn't straightforward or reliable. So we bail for now.
#
# See: https://github.com/pypa/pip-audit/issues/732
base_tmpdir = gettempdir()
raise VirtualEnvError(
f"Couldn't execute in a temporary directory under {base_tmpdir}. "
"This is sometimes caused by a noexec mount flag or other setting. "
"Consider changing this setting or explicitly specifying a different "
"temporary directory via the TMPDIR environment variable."
)
def post_setup(self, context: SimpleNamespace) -> None:
"""
Install the custom package and populate the list of installed packages.
This method is overridden from `EnvBuilder` to execute immediately after the virtual
environment has been created and should not be called directly.
We do a few things in our custom post-setup:
- Upgrade `pip` itself, since `ensurepip` can leave us with an old version and the
`--dry-run`/`--report` flow used below needs a reasonably recent one (pip 22.2+).
- Install `wheel`. When our packages install their own dependencies, they might be able
to do so through wheels, which are much faster and don't require us to run
setup scripts.
- Execute the custom install command as a `--dry-run` with a JSON `--report`.
- Parse the resulting install report into the list of packages returned when the
`installed_packages` property is queried.
"""
self._state.update_state("Updating pip installation in isolated environment")
# Firstly, upgrade our `pip` versions since `ensurepip` can leave us with an old version
# and install `wheel` in case our package dependencies are offered as wheels
# TODO: This is probably replaceable with the `upgrade_deps` option on `EnvBuilder`
# itself, starting with Python 3.9.
pip_upgrade_cmd = [
context.env_exe,
"-m",
"pip",
"install",
"--upgrade",
"pip",
"wheel",
"setuptools",
]
try:
run(pip_upgrade_cmd, state=self._state)
except CalledProcessError as cpe:
raise VirtualEnvError(f"Failed to upgrade `pip`: {pip_upgrade_cmd}") from cpe
self._state.update_state("Installing package in isolated environment")
with TemporaryDirectory() as ve_dir, NamedTemporaryFile(dir=ve_dir, delete=False) as tmp:
# We use delete=False in creating the tempfile to allow it to be
# closed and opened multiple times within the context scope on
# windows, see GitHub issue #646.
# Install our packages
# NOTE(ww): We pass `--no-input` to prevent `pip` from indefinitely
# blocking on user input for repository credentials, and
# `--keyring-provider=subprocess` to allow `pip` to access the `keyring`
# program on the `$PATH` for index credentials, if necessary. The latter flag
# is required beginning with pip 23.1, since `--no-input` disables the default
# keyring behavior.
package_install_cmd = [
context.env_exe,
"-m",
"pip",
"install",
"--no-input",
"--keyring-provider=subprocess",
*self._index_url_args,
"--dry-run",
"--report",
tmp.name,
*self._install_args,
]
try:
run(package_install_cmd, log_stdout=True, state=self._state)
except CalledProcessError as cpe:
# TODO: Propagate the subprocess's error output better here.
logger.error(f"internal pip failure: {cpe.stderr}")
raise VirtualEnvError(f"Failed to install packages: {package_install_cmd}") from cpe
self._state.update_state("Processing package list from isolated environment")
install_report = json.load(tmp)
package_list = install_report["install"]
# Convert into a series of name, version pairs
self._packages = []
for package in package_list:
package_metadata = package["metadata"]
self._packages.append(
(package_metadata["name"], Version(package_metadata["version"]))
)
@property
def installed_packages(self) -> Iterator[tuple[str, Version]]:
"""
A property to inspect the list of packages installed in the virtual environment.
This property can only be inspected after the `create` method has been called.
"""
if self._packages is None:
raise VirtualEnvError(
"Invalid usage of wrapper."
"The `create` method must be called before inspecting `installed_packages`."
)
yield from self._packages
@property
def _index_url_args(self) -> list[str]:
args = []
if self._index_url:
args.extend(["--index-url", self._index_url])
for index_url in self._extra_index_urls:
args.extend(["--extra-index-url", index_url])
return args
class VirtualEnvError(Exception):
"""
Raised when `VirtualEnv` fails to build or inspect dependencies, for any reason.
"""
pass
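# For reference, a trimmed sketch of the `pip install --report` document that
# `post_setup` parses; only the fields read above are shown, and the values
# are hypothetical.
if __name__ == "__main__":
    install_report = {
        "install": [
            {"metadata": {"name": "example-pkg", "version": "1.2.3"}},
        ]
    }
    packages = [
        (p["metadata"]["name"], Version(p["metadata"]["version"]))
        for p in install_report["install"]
    ]
    print(packages)  # [('example-pkg', <Version('1.2.3')>)]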