This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,13 @@
from .tool_inspector import ToolInspector
from .factory import tool_commands
from .main import configure_system, configure_alias
from .base import ToolResult
# Public API re-exported at the package level.
__all__ = [
    "ToolInspector",
    "tool_commands",
    "configure_system",
    "configure_alias",
    "ToolResult",
]

View File

@@ -0,0 +1,68 @@
import base64
import json
import typer
from urllib.parse import urlsplit, urlunsplit
from safety.tool.constants import (
NPMJS_PUBLIC_REPOSITORY_URL,
PYPI_PUBLIC_REPOSITORY_URL,
)
from typing import Optional, Literal
def index_credentials(ctx: typer.Context) -> str:
    """
    Returns the index credentials for the current context.

    This should be used together with user:index_credential for index
    basic auth.

    Args:
        ctx (typer.Context): The context.

    Returns:
        str: The index credentials.
    """
    access_token = None
    key = None

    auth = getattr(ctx.obj, "auth", None)
    if auth:
        client = auth.client
        if client.token:
            access_token = client.token.get("access_token")
        key = client.api_key

    # Versioned envelope so consumers can evolve the credential format.
    payload = {
        "version": "1.0",
        "access_token": access_token,
        "api_key": key,
        "project_id": ctx.obj.project.id if ctx.obj.project else None,
    }
    envelope = json.dumps(payload)
    return base64.urlsafe_b64encode(envelope.encode("utf-8")).decode("utf-8")
def build_index_url(
    ctx: typer.Context, index_url: Optional[str], index_type: Literal["pypi", "npm"]
) -> str:
    """
    Builds the index URL for the current context, embedding the Safety
    credential envelope as HTTP basic-auth userinfo.

    Args:
        ctx (typer.Context): The context used to derive credentials.
        index_url: Explicit index URL, or None to use the public repository
            for ``index_type``.
        index_type: Target ecosystem ("pypi" or "npm").

    Returns:
        str: The index URL with a ``user:<credentials>@`` netloc.
    """
    if index_url is None:
        # TODO: Make this to select the index based on auth org or project
        index_url = {
            "pypi": PYPI_PUBLIC_REPOSITORY_URL,
            "npm": NPMJS_PUBLIC_REPOSITORY_URL,
        }[index_type]

    url = urlsplit(index_url)
    encoded_auth = index_credentials(ctx)

    # urlsplit preserves the str/bytes type of its input; use isinstance
    # (not exact `type(...) is` checks) so subclasses are handled too, and
    # build the bytes netloc in bytes to avoid embedding a b'...' repr.
    if isinstance(url.netloc, bytes):
        netloc = b"user:" + encoded_auth.encode("utf-8") + b"@" + url.netloc
    else:
        netloc = f"user:{encoded_auth}@{url.netloc}"
    url = url._replace(netloc=netloc)
    return urlunsplit(url)

View File

@@ -0,0 +1,626 @@
from abc import ABC, abstractmethod
import json
import sys
from pathlib import Path
import shutil
import subprocess
import time
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Literal, Mapping
from dataclasses import dataclass
import typer
from safety.events.utils import emit_tool_command_executed
from safety.models import ToolResult
from safety.tool.constants import (
PROJECT_CONFIG,
MOST_FREQUENTLY_DOWNLOADED_PYPI_PACKAGES,
)
from safety.tool.typosquatting import TyposquattingProtection
from safety.utils.pyapp_utils import get_env
from .environment_diff import EnvironmentDiffTracker
from .intents import CommandToolIntention, ToolIntentionType, Dependency
from .resolver import get_unwrapped_command
from safety_schemas.models.events.types import ToolType
from safety.events.utils import emit_diff_operations
from .utils import (
is_os_supported,
)
import logging
logger = logging.getLogger(__name__)
class BaseCommand(ABC):
    """
    Abstract base class for tool commands.
    Requires subclasses to implement all required attributes.
    """

    def __init__(
        self,
        args: List[str],
        capture_output: bool = False,
        intention: Optional[CommandToolIntention] = None,
        command_alias_used: Optional[str] = None,
    ) -> None:
        """
        Initialize the command.

        Args:
            args: Command arguments
            capture_output: Whether to capture command output
            intention: Parsed command intention, when available
            command_alias_used: Alias the tool was invoked through, if any
        """
        self._args = args
        self._intention = intention
        self._capture_output = capture_output
        self._command_alias_used = command_alias_used
        self._tool_type = self.get_tool_type()
        # Name-mangled attribute: internal to BaseCommand on purpose.
        self.__typosquatting_protection = self._build_typosquatting_protection()
        self._diff_tracker = self.get_diff_tracker()
        self._should_track_state = self.should_track_state()

    @abstractmethod
    def get_tool_type(self) -> ToolType:
        """
        Get the tool type for this command type.
        Must be implemented by subclasses.

        Returns:
            ToolType: Tool type
        """
        pass

    @abstractmethod
    def get_command_name(self) -> List[str]:
        """
        Get the command name for this command type.
        Must be implemented by subclasses.

        Returns:
            List[str]: Command name as a list (e.g. ["pip"])
        """
        pass

    def get_ecosystem(self) -> Literal["pypi", "npmjs"]:
        """
        Get the ecosystem for this command type.
        Must be implemented by subclasses.

        Returns:
            Literal["pypi", "npmjs"]: Ecosystem
        """
        # Default ecosystem; non-PyPI tools override this.
        return "pypi"

    @abstractmethod
    def get_diff_tracker(self) -> EnvironmentDiffTracker:
        """
        Get the diff tracker instance for this command type.
        Must be implemented by subclasses.

        Returns:
            EnvironmentDiffTracker: Diff tracker instance
        """
        pass

    def should_track_state(self) -> bool:
        """
        Determine if this command should track state changes.
        Subclasses can override for more sophisticated logic.

        Returns:
            bool: True if state changes should be tracked
        """
        if self._intention:
            return self._intention.modifies_packages()
        return False

    def _get_typosquatting_reference_packages(self) -> Tuple[str, ...]:
        """
        Return the corpus of well-known package names used by the
        TypoSquatting protection to validate/correct package names.
        Child classes should override this if they target a different
        package ecosystem (e.g., npm) or want a custom corpus.

        Returns:
            Tuple[str, ...]: Default set of popular PyPI package names.
        """
        return MOST_FREQUENTLY_DOWNLOADED_PYPI_PACKAGES

    def _build_typosquatting_protection(self) -> TyposquattingProtection:
        """
        Factory method for the TypoSquatting protection instance.
        Child classes may override this to customize the protection
        strategy entirely (not only the corpus), if needed.

        Returns:
            TyposquattingProtection: Configured protection instance.
        """
        return TyposquattingProtection(self._get_typosquatting_reference_packages())

    def get_package_list_command(self) -> List[str]:
        """
        Get the command to list installed packages.
        Subclasses must override this to provide the correct command.

        Returns:
            List[str]: Command to list packages in JSON format
        """
        # Default implementation, should be overridden by subclasses
        return [*self.get_command_name(), "list", "-v", "--format=json"]

    def parse_package_list_output(self, output: str) -> List[Dict[str, Any]]:
        """
        Parse the output of the package list command.
        Subclasses can override this for custom parsing logic.

        Args:
            output: Command output

        Returns:
            List[Dict[str, Any]]: List of package dictionaries
        """
        # Default implementation assumes JSON output
        try:
            return json.loads(output)
        except json.JSONDecodeError:
            # Log error and return empty list
            logger.exception(f"Error parsing package list output: {output[:100]}...")
            return []

    def _initialize_diff_tracker(self, ctx: typer.Context) -> None:
        """
        Common implementation to initialize the diff tracker.
        Can be called by child classes in their before() implementation.
        """
        current_packages = self._get_installed_packages(ctx)
        self._diff_tracker.set_before_state(current_packages)

    def __run_scan_if_needed(self, ctx: typer.Context, silent: bool = True) -> None:
        # Only scan when the OS is supported and a project config file
        # exists in the current working directory.
        if not is_os_supported():
            return

        target = Path.cwd()
        if (target / PROJECT_CONFIG).is_file():
            if silent:
                self.__run_silent_scan(ctx, target)
            else:
                from safety.init.command import init_scan_ui

                init_scan_ui(ctx, prompt_user=True)

    def __run_silent_scan(self, ctx: typer.Context, target: Path) -> None:
        """
        Run a scan silently without displaying progress.
        """
        target_arg = str(target.resolve())
        CMD = ("safety", "scan", "--target", target_arg)
        logger.info(f"Launching silent scan: {CMD}")

        try:
            # Detach the scan so it keeps running independently of this
            # command (new process group on Windows, new session elsewhere).
            kwargs = {
                "stdout": subprocess.DEVNULL,
                "stderr": subprocess.DEVNULL,
                "stdin": subprocess.DEVNULL,
                "shell": False,
            }
            if sys.platform == "win32":
                kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
            else:
                kwargs["start_new_session"] = True

            subprocess.Popen(CMD, **kwargs)
        except Exception as e:
            # Best-effort: a failed background scan must not break the tool run.
            logger.error(f"Failed to start independent scan: {e}")

    def _handle_command_result(self, ctx: typer.Context, result: ToolResult) -> None:
        """
        Common implementation to handle command results.
        Can be called by child classes in their after() implementation.
        """
        process = result.process
        if process:
            # Diff/scan only on success; the executed event is always emitted.
            if process.returncode == 0 and self._should_track_state:
                self._perform_diff(ctx, result.tool_path)
                self.__run_scan_if_needed(ctx, silent=True)

            emit_tool_command_executed(
                ctx.obj.event_bus,
                ctx,  # type: ignore
                tool=self._tool_type,
                result=result,
            )

    def is_installed(self) -> bool:
        """
        Checks if the tool program is reachable

        Returns:
            True if the tool is reachable on system, or false otherwise
        """
        cmd_name = self.get_command_name()[0]
        return shutil.which(cmd_name) is not None

    def before(self, ctx: typer.Context) -> None:
        """
        Pre-execution hook: snapshot the package state (when tracked) and
        apply typosquatting corrections to package arguments, except for
        removals (uninstalling a typo'd name must target the literal name).
        """
        if self._should_track_state:
            self._initialize_diff_tracker(ctx)

        if (
            self._intention
            and self._intention.packages
            and self._intention.intention_type is not ToolIntentionType.REMOVE_PACKAGE
        ):
            for dep in self._intention.packages:
                if reviewed_name := self.__typosquatting_protection.coerce(
                    self._intention, dep.name
                ):
                    dep.corrected_text = dep.original_text.replace(
                        dep.name, reviewed_name
                    )
                    # NOTE: Mutation here is a workaround, it should be improved in the future.
                    dep.name = reviewed_name
                    self._args[dep.arg_index] = dep.corrected_text

    def after(self, ctx: typer.Context, result: ToolResult) -> None:
        """
        Post-execution hook: delegate to the common result handler.
        """
        self._handle_command_result(ctx, result)

    def execute(self, ctx: typer.Context) -> ToolResult:
        """
        Run the wrapped tool command, surrounded by before()/after() hooks.

        Returns:
            ToolResult: Completed process, duration and resolved tool path.
        """
        self.before(ctx)

        # TODO: Safety should redirect to the proper pip/tool, if the user is
        # using pip3, it should be redirected to pip3, not pip to avoid any
        # issues.
        cmd = self.get_command_name()
        cmd_name = cmd[0]
        logger.debug(f"Getting unwrapped command for: {cmd_name}")
        tool_path = get_unwrapped_command(name=cmd_name)
        logger.debug(f"Resolved tool_path: {tool_path}")
        pre_args = [tool_path] + cmd[1:]
        args = pre_args + self.__remove_safety_args(self._args)
        logger.debug(f"Final command args: {args}")
        started_at = time.monotonic()
        logger.debug(f"Running subprocess with capture_output={self._capture_output}")
        process = subprocess.run(
            args, capture_output=self._capture_output, env=self.env(ctx)
        )
        logger.debug(f"Subprocess completed with returncode: {process.returncode}")
        duration_ms = int((time.monotonic() - started_at) * 1000)

        result = ToolResult(
            process=process, duration_ms=duration_ms, tool_path=tool_path
        )
        self.after(ctx, result)
        return result

    def env(self, ctx: typer.Context):
        """
        Returns the environment.

        Args:
            ctx (typer.Context): The context.

        Returns:
            dict: The environment.
        """
        return get_env()

    def __remove_safety_args(self, args: List[str]) -> List[str]:
        # Strip Safety-specific flags (--safety*) before delegating to the tool.
        return [arg for arg in args if not arg.startswith("--safety")]

    def _get_installed_packages(self, ctx: typer.Context) -> List[Dict[str, Any]]:
        """
        Get currently installed packages
        """
        command = self.get_package_list_command()
        base_cmd = [get_unwrapped_command(name=command[0])]
        args = base_cmd + command[1:]
        result = subprocess.run(args, capture_output=True, env=self.env(ctx), text=True)
        return self.parse_package_list_output(result.stdout)

    def _perform_diff(self, ctx: typer.Context, tool_path: Optional[str] = None) -> None:
        """
        Perform the diff operation.
        Can be called by child classes when appropriate.
        """
        current_packages = self._get_installed_packages(ctx)
        self._diff_tracker.set_after_state(current_packages)
        added, removed, updated = self._diff_tracker.get_diff()

        emit_diff_operations(
            ctx.obj.event_bus,
            ctx,  # type: ignore
            added=added,
            removed=removed,
            updated=updated,
            tool_path=tool_path,
            by_tool=self._tool_type,
        )
@dataclass
class ParsedCommand:
    """
    Represents a parsed command with its hierarchy
    """

    chain: List[str]  # e.g., ['pip', 'install'] or ['add']
    # Normalized intention the chain resolved to.
    intention: ToolIntentionType
    remaining_args_start: int  # Where options/packages start
class ToolCommandLineParser(ABC):
    """
    Base implementation of a command line parser for tools
    """

    def __init__(self):
        self._tool_name = self.get_tool_name()

    @abstractmethod
    def get_tool_name(self) -> str:
        """
        Return the tool name this parser handles (e.g. "pip").
        """
        pass

    @abstractmethod
    def get_command_hierarchy(self) -> Mapping[str, Union[ToolIntentionType, Mapping]]:
        """
        Return command hierarchy only. No option definitions needed.

        Example:
            {
                'add': ToolIntentionType.ADD_PACKAGE,
                'pip': {
                    'install': ToolIntentionType.ADD_PACKAGE,
                    'uninstall': ToolIntentionType.REMOVE_PACKAGE
                }
            }
        """
        pass

    @abstractmethod
    def get_known_flags(self) -> Dict[str, Set[str]]:
        """
        Return known flags that don't take values.
        Format: {command_path: {flag_names}}

        Example:
            {
                'global': {'verbose', 'v', 'quiet', 'q', 'help', 'h'},
                'install': {'upgrade', 'U', 'dry-run', 'no-deps', 'user'}
            }
        """
        pass

    def parse(
        self, args: List[str], start_from: int = 0
    ) -> Optional[CommandToolIntention]:
        """
        Main parsing method

        Args:
            args: Raw argument list.
            start_from: Index in ``args`` to start scanning for commands.

        Returns:
            Parsed intention, or None when no known command was matched.
        """
        parsed_command = self._parse_command_hierarchy(args, start_from)
        if not parsed_command:
            return None

        remaining_args = args[parsed_command.remaining_args_start :]
        options, packages = self._parse_options_and_packages(
            remaining_args, parsed_command
        )

        return CommandToolIntention(
            tool=self._tool_name,
            command=" ".join(parsed_command.chain),
            command_chain=parsed_command.chain,
            intention_type=parsed_command.intention,
            packages=packages,
            options=options,
            raw_args=args.copy(),
        )

    def _is_known_flag(self, option_key: str, command_chain: List[str]) -> bool:
        """
        Check if option is a known flag using command context
        """
        known_flags = self.get_known_flags()

        # Try command-specific flags first, then global
        # e.g. for chain ['pip', 'install']: "pip.install", "pip", "global".
        candidates = []
        if command_chain:
            for i in range(len(command_chain), 0, -1):
                candidates.append(".".join(command_chain[:i]))
        candidates.append("global")

        for candidate in candidates:
            if candidate in known_flags and option_key in known_flags[candidate]:
                return True
        return False

    def _parse_command_hierarchy(
        self, args: List[str], start_from: int
    ) -> Optional[ParsedCommand]:
        """
        Parse the command hierarchy - stop at first non-command
        """
        if not args or start_from >= len(args):
            return None

        hierarchy = self.get_command_hierarchy()
        command_chain = []
        current_level = hierarchy
        i = start_from

        while i < len(args):
            arg = args[i].lower()

            # Check if this argument is a valid command at current level
            if isinstance(current_level, Mapping) and arg in current_level:
                command_chain.append(arg)
                current_level = current_level[arg]

                # If we hit an intention type, we're done with commands
                if isinstance(current_level, ToolIntentionType):
                    return ParsedCommand(
                        chain=command_chain,
                        intention=current_level,
                        remaining_args_start=i + 1,
                    )
            i += 1

        # Check if we ended on a valid intention
        if isinstance(current_level, ToolIntentionType):
            return ParsedCommand(
                chain=command_chain, intention=current_level, remaining_args_start=i
            )
        return None

    def _parse_options_and_packages(
        self, args: List[str], parsed_command: ParsedCommand
    ) -> Tuple[Dict[str, Any], List[Dependency]]:
        """
        Simple parsing: hyphens = options, everything else = packages/args
        """
        options = {}
        packages = []
        i = 0
        while i < len(args):
            arg = args[i]
            if arg.startswith("-"):
                # _parse_option reports how many args it consumed (1 or 2).
                option_key, option_data, consumed = self._parse_option(
                    args, i, parsed_command
                )
                options[option_key] = option_data
                i += consumed
            else:
                arg_index = parsed_command.remaining_args_start + i
                dep = self._try_parse_package(arg, arg_index, parsed_command)
                if dep:
                    packages.append(dep)
                else:
                    self._store_unknown_argument(options, arg, arg_index)
                i += 1
        return options, packages

    def _parse_option(
        self, args: List[str], i: int, parsed_command: ParsedCommand
    ) -> Tuple[str, Dict[str, Any], int]:
        """
        Parse a single option, args[i] is expected to be a hyphenated option

        Returns:
            Tuple of (option key, option metadata, number of args consumed).
        """
        arg = args[i]
        # arg_index is the position in the ORIGINAL argument list.
        arg_index = parsed_command.remaining_args_start + i

        # Handle --option=value format
        if "=" in arg:
            option_part, value_part = arg.split("=", 1)
            option_key = option_part.lstrip("-")
            option_data = {
                "arg_index": arg_index,
                "raw_option": option_part,
                "value": value_part,
            }
            return option_key, option_data, 1

        # Handle --option, -option formats for known flags
        option_key = arg.lstrip("-")
        if self._is_known_flag(option_key, parsed_command.chain):
            # It's a flag - doesn't take value
            option_data = {
                "arg_index": arg_index,
                "raw_option": arg,
                "value": True,
            }
            return option_key, option_data, 1

        # Handle --option value, -option value formats
        if i + 1 < len(args) and not args[i + 1].startswith("-"):
            option_data = {
                "arg_index": arg_index,
                "raw_option": arg,
                "value": args[i + 1],
                "value_index": arg_index + 1,
            }
            return option_key, option_data, 2

        # Handle --option, -option formats for unknown flags
        option_data = {
            "arg_index": arg_index,
            "raw_option": arg,
            "value": True,
        }
        return option_key, option_data, 1

    def _should_parse_as_package(self, intention: ToolIntentionType) -> bool:
        """
        Check if arguments should be parsed as packages
        """
        return intention in [
            ToolIntentionType.ADD_PACKAGE,
            ToolIntentionType.REMOVE_PACKAGE,
            ToolIntentionType.DOWNLOAD_PACKAGE,
            ToolIntentionType.SEARCH_PACKAGES,
        ]

    def _try_parse_package(
        self, arg: str, index: int, parsed_command: ParsedCommand
    ) -> Optional[Dependency]:
        """
        Try to parse argument as package, return None if fails
        """
        if self._should_parse_as_package(parsed_command.intention):
            return self._parse_package_spec(arg, index)
        return None

    def _store_unknown_argument(self, options: Dict, arg: str, index: int) -> None:
        """
        Store non-package arguments in options as unknown
        """
        # Keys are "unknown_<n>" where n counts existing unknown entries.
        key = f"unknown_{len([k for k in options.keys() if k.startswith('unknown_')])}"
        options[key] = {
            "arg_index": index,
            "value": arg,
        }

    def _parse_package_spec(
        self, spec_str: str, arg_index: int
    ) -> Optional[Dependency]:
        """
        Parse a requirement string into a Dependency; None when unparseable.
        """
        try:
            from packaging.requirements import Requirement

            # TODO: pip install . should be excluded
            req = Requirement(spec_str)

            return Dependency(
                name=req.name,
                version_constraint=str(req.specifier),
                extras=req.extras,
                arg_index=arg_index,
                original_text=spec_str,
            )
        except Exception:
            # If spec parsing fails, just ignore for now
            return None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,57 @@
from functools import wraps
from pathlib import Path
from rich.console import Console
from safety_schemas.models import ProjectModel
from safety.console import main_console
from safety.tool.constants import (
MSG_NOT_AUTHENTICATED_TOOL,
MSG_NOT_AUTHENTICATED_TOOL_NO_TTY,
)
from ..codebase_utils import load_unverified_project_from_config
from ..scan.util import GIT
def prepare_tool_execution(func):
    """
    Decorator that prepares the CLI context before a tool command runs.

    It forces the main console onto the context, warns when the user is not
    authenticated (interactively or via stderr for non-TTY sessions), and
    attaches project/git metadata when a project config exists at ``target``.
    """

    @wraps(func)
    def wrapper(ctx, target: Path, *args, **kwargs):
        ctx.obj.console = main_console
        ctx.params.pop("console", None)

        if not ctx.obj.auth.is_valid():
            tool_name = ctx.command.name.title()
            if not ctx.obj.console.is_interactive:
                # Non-interactive session: emit guidance on stderr only.
                Console(stderr=True).print(
                    MSG_NOT_AUTHENTICATED_TOOL_NO_TTY.format(tool_name=tool_name)
                )
            else:
                ctx.obj.console.line()
                ctx.obj.console.print(
                    MSG_NOT_AUTHENTICATED_TOOL.format(tool_name=tool_name)
                )
                ctx.obj.console.line()

                from safety.cli_util import process_auth_status_not_ready

                process_auth_status_not_ready(
                    console=main_console, auth=ctx.obj.auth, ctx=ctx
                )

        unverified_project = load_unverified_project_from_config(project_root=target)
        project_id = unverified_project.id
        if project_id:
            ctx.obj.project = ProjectModel(
                id=project_id,
                name=unverified_project.name,
                project_path=unverified_project.project_path,
            )
            ctx.obj.project.git = GIT(root=target).build_git_data()

        return func(ctx, target=target, *args, **kwargs)

    return wrapper

View File

@@ -0,0 +1,98 @@
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, Field
from safety.cli_util import CommandType, FeatureType
from safety.constants import CONTEXT_COMMAND_TYPE, CONTEXT_FEATURE_TYPE
class ContextSettingsModel(BaseModel):
    """
    Model for command context settings.
    """

    allow_extra_args: bool = Field(default=True)
    ignore_unknown_options: bool = Field(default=True)
    command_type: CommandType = Field(default=CommandType.BETA)
    feature_type: FeatureType = Field(default=FeatureType.FIREWALL)
    help_option_names: List[str] = Field(default=["--safety-help"])

    def as_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary format expected by Typer.

        Returns:
            Dict[str, Any]: Dictionary representation of the context settings
        """
        # Safety-specific metadata travels under the CONTEXT_* keys.
        return {
            "allow_extra_args": self.allow_extra_args,
            "ignore_unknown_options": self.ignore_unknown_options,
            CONTEXT_COMMAND_TYPE: self.command_type,
            CONTEXT_FEATURE_TYPE: self.feature_type,
            "help_option_names": self.help_option_names,
        }
class CommandSettingsModel(BaseModel):
    """
    Model for command settings used in the Typer decorator.
    """

    # Help text shown for the command.
    help: str
    # Command name as displayed in the CLI.
    name: str
    # Metavar placeholder used in usage strings.
    options_metavar: str = Field(default="[OPTIONS]")
    # Context settings applied to the command.
    context_settings: ContextSettingsModel = Field(default_factory=ContextSettingsModel)
class ToolCommandModel(BaseModel):
    """
    Model for a tool command definition.
    """

    name: str
    display_name: str
    help: str
    # Path to custom Typer app if available
    custom_app: Optional[str] = None
    # Custom command settings for the tool
    command_settings: Optional[CommandSettingsModel] = None

    def get_command_settings(self) -> CommandSettingsModel:
        """
        Get command settings, using defaults if not specified.

        Returns:
            CommandSettingsModel: Command settings with defaults
        """
        if self.command_settings is not None:
            return self.command_settings
        # Fall back to settings derived from the tool's own metadata.
        return CommandSettingsModel(help=self.help, name=self.display_name)
# Tool definitions
# Registry of firewall-wrapped tool commands; each entry is registered as a
# `safety <name>` CLI command by the tool command factory.
TOOLS = [
    ToolCommandModel(
        name="poetry",
        display_name="poetry",
        help="[BETA] Run poetry commands protected by Safety firewall.\nExample: safety poetry add httpx",
    ),
    ToolCommandModel(
        name="pip",
        display_name="pip",
        help="[BETA] Run pip commands protected by Safety firewall.\nExample: safety pip list",
    ),
    ToolCommandModel(
        name="uv",
        display_name="uv",
        help="[BETA] Run uv commands protected by Safety firewall.\nExample: safety uv pip list",
    ),
    ToolCommandModel(
        name="npm",
        display_name="npm",
        help="[BETA] Run npm commands protected by Safety firewall.\nExample: safety npm list",
    ),
]

View File

@@ -0,0 +1,165 @@
from typing import Dict, List, Tuple, Any, Callable, TypeVar, Generic, NamedTuple
from packaging.utils import canonicalize_name, canonicalize_version
T = TypeVar("T") # For the package data type
K = TypeVar("K") # For the key type
V = TypeVar("V") # For the value type
class PackageLocation(NamedTuple):
    """
    Composite key representing package name and location.
    """

    # Package name (ecosystem-specific normalization applied by subclasses).
    name: str
    # Install location of the package; "" when not reported.
    location: str
class EnvironmentDiffTracker(Generic[T, K, V]):
    """
    Generic utility class to track changes in environment states before and
    after operations. Can be used with any environment management system
    (pip, npm, apt, docker, etc.).
    """

    def __init__(
        self,
        key_extractor: Callable[[T], K],
        value_extractor: Callable[[T], V],
    ) -> None:
        """
        Initialize a new environment diff tracker.

        Args:
            key_extractor: Function to extract the item identifier from an entry
            value_extractor: Function to extract the version or other value to
                compare
        """
        self._key_extractor = key_extractor
        self._value_extractor = value_extractor
        self._before_items: Dict[K, V] = {}
        self._after_items: Dict[K, V] = {}

    def set_before_state(self, items_data: List[T]) -> None:
        """
        Set the before-operation environment state.

        Args:
            items_data: List of items in the format specific to the environment
        """
        self._before_items = self._normalize_items_data(items_data)

    def set_after_state(self, items_data: List[T]) -> None:
        """
        Set the after-operation environment state.

        Args:
            items_data: List of items in the format specific to the environment
        """
        self._after_items = self._normalize_items_data(items_data)

    def get_diff(self) -> Tuple[Dict[K, V], Dict[K, V], Dict[K, Tuple[V, V]]]:
        """
        Compute the difference between before and after environment states.

        Returns:
            Tuple containing:
            - Dictionary of added items {key: value}
            - Dictionary of removed items {key: value}
            - Dictionary of updated items {key: (old_value, new_value)}
        """
        before_keys = set(self._before_items.keys())
        after_keys = set(self._after_items.keys())

        # Find added and removed items
        added_keys = after_keys - before_keys
        removed_keys = before_keys - after_keys

        # Find updated items (same key, different value)
        common_keys = before_keys & after_keys
        updated = {
            key: (self._before_items[key], self._after_items[key])
            for key in common_keys
            if self._before_items[key] != self._after_items[key]
        }

        added = {key: self._after_items[key] for key in added_keys}
        removed = {key: self._before_items[key] for key in removed_keys}
        return added, removed, updated

    def _normalize_items_data(self, items_data: List[T]) -> Dict[K, V]:
        """
        Normalize items data into a standardized dictionary format.

        Args:
            items_data: List of item data entries

        Returns:
            Dict mapping normalized item keys to their values
        """
        result = {}
        for item_info in items_data:
            try:
                key = self._key_extractor(item_info)
                value = self._value_extractor(item_info)
                result[key] = value
            except (KeyError, TypeError, AttributeError):
                # Skip entries that don't have the expected structure
                continue
        return result
class PipEnvironmentDiffTracker(
    EnvironmentDiffTracker[Dict[str, Any], PackageLocation, str]
):
    """
    Specialized diff tracker for pip package environments.

    Packages are keyed by canonicalized name plus install location, and
    compared by canonicalized version string.
    """

    def __init__(self):
        super().__init__(
            key_extractor=self._pip_key_extractor,
            value_extractor=self._pip_value_extractor,
        )

    # TODO: handle errors in value extraction
    def _pip_key_extractor(self, pkg: Dict[str, Any]) -> PackageLocation:
        raw_name = pkg.get("name", "")
        raw_location = pkg.get("location", "")
        return PackageLocation(
            name=canonicalize_name(raw_name), location=raw_location
        )

    def _pip_value_extractor(self, pkg: Dict[str, Any]) -> str:
        raw_version = pkg.get("version", "")
        return canonicalize_version(raw_version, strip_trailing_zero=False)
class NpmEnvironmentDiffTracker(
    EnvironmentDiffTracker[Dict[str, Any], PackageLocation, str]
):
    """
    Specialized diff tracker for npm package environments.

    Packages are keyed by their reported name plus install location and
    compared by raw version string.
    """

    def __init__(self):
        super().__init__(
            key_extractor=self._npm_key_extractor,
            value_extractor=self._npm_value_extractor,
        )

    # TODO: handle errors in value extraction
    def _npm_key_extractor(self, pkg: Dict[str, Any]) -> PackageLocation:
        raw_name = pkg.get("name", "")
        raw_location = pkg.get("location", "")
        return PackageLocation(name=raw_name, location=raw_location)

    def _npm_value_extractor(self, pkg: Dict[str, Any]) -> str:
        return pkg.get("version", "")

View File

@@ -0,0 +1,185 @@
"""
Factory for creating and registering package manager commands.
"""
import importlib
import logging
from pathlib import Path
import sys
from typing import TYPE_CHECKING, cast
import typer
from safety.decorators import notify
from safety.error_handlers import handle_cmd_exception
from safety.tool.decorators import prepare_tool_execution
from .definitions import TOOLS, ToolCommandModel
try:
from typing import Annotated # type: ignore[import]
except ImportError:
from typing_extensions import Annotated
if TYPE_CHECKING:
from safety.cli_util import SafetyCLILegacyGroup
from safety.tool import ToolResult
from safety.cli_util import CustomContext
logger = logging.getLogger(__name__)
class ToolCommandFactory:
    """
    Factory for creating command apps per tool.
    """

    def _get_command_class_name(self, pkg_name: str) -> str:
        """
        Get the command class name for a package manager.

        Args:
            pkg_name: Name of the package manager

        Returns:
            str: Command class name
        """
        # e.g. "pip" -> "PipCommand"
        return f"{pkg_name.capitalize()}Command"

    def _create_tool_group(
        self,
        *,
        tool_command: ToolCommandModel,
        command_class_name: str,
    ) -> typer.Typer:
        """
        Create a standard app for a package manager based on tool command model.

        Args:
            tool_command: Tool command model with configuration
            command_class_name: Name of the command class

        Returns:
            typer.Typer: The created Typer group
        """
        # Get command settings from the tool command model
        cmd_settings = tool_command.get_command_settings()

        # Imported lazily to avoid import cycles with safety.cli_util.
        from safety.cli_util import SafetyCLICommand, SafetyCLISubGroup

        app = typer.Typer(rich_markup_mode="rich", cls=SafetyCLISubGroup)

        # Main command
        @app.command(
            cls=SafetyCLICommand,
            help=cmd_settings.help,
            name=cmd_settings.name,
            options_metavar=cmd_settings.options_metavar,
            context_settings=cmd_settings.context_settings.as_dict(),
        )
        @handle_cmd_exception
        @prepare_tool_execution
        @notify
        def tool_main_command(
            ctx: typer.Context,
            target: Annotated[
                Path,
                typer.Option(
                    exists=True,
                    file_okay=False,
                    dir_okay=True,
                    writable=False,
                    readable=True,
                    resolve_path=True,
                    show_default=False,
                ),  # type: ignore
            ] = Path("."),
        ):
            """
            Base command handler that forwards to the appropriate command class.

            Args:
                ctx: Typer context
            """
            # Get the command class directly using importlib
            module_name = f"safety.tool.{tool_command.name}.command"
            try:
                module = importlib.import_module(module_name)
                command_class = getattr(module, command_class_name, None)
            except ImportError:
                logger.error(f"Could not import {module_name}")
                command_class = None

            if not command_class:
                typer.echo(f"Command class {command_class_name} not found")
                return

            parent_ctx = cast("CustomContext", ctx.parent)
            command = command_class.from_args(
                ctx.args,
                command_alias_used=parent_ctx.command_alias_used,
            )

            if not command.is_installed():
                typer.echo(f"Tool {tool_command.name} is not installed.")
                sys.exit(1)

            # Propagate the tool's exit code to the caller.
            result: "ToolResult" = command.execute(ctx)
            if result.process.returncode != 0:
                sys.exit(result.process.returncode)

        # We can support subcommands in the future
        return app

    def auto_register_tools(self, group: "SafetyCLILegacyGroup") -> None:
        """
        Auto-register commands from the definitions configuration.

        Args:
            group: The main Safety CLI group
        """
        for tool_command_config in TOOLS:
            tool_name = tool_command_config.name

            # Get the command class name
            command_class_name = self._get_command_class_name(tool_name)

            tool_app = None

            # First check if custom_app is specified in the tool model
            if tool_command_config.custom_app:
                try:
                    module_path, attr_name = tool_command_config.custom_app.rsplit(
                        ".", 1
                    )
                    module = importlib.import_module(module_path)
                    tool_app = getattr(module, attr_name, None)
                    if not tool_app:
                        logger.error(
                            f"Custom app {attr_name} not found in {module_path}"
                        )
                except (ImportError, AttributeError, ValueError) as e:
                    logger.exception(
                        f"Failed to import custom app for {tool_name}: {e}"
                    )

            # If no custom_app or it failed, create the tool app
            if not tool_app:
                tool_app = self._create_tool_group(
                    tool_command=tool_command_config,
                    command_class_name=command_class_name,
                )
                # We can support subcommands in the future

            # Register the tool app
            group.add_command(typer.main.get_command(tool_app), name=tool_name)
            logger.info(f"Registered auto-generated command for {tool_name}")
# Module-level singleton used to register tool commands on the CLI group.
tool_commands = ToolCommandFactory()

View File

@@ -0,0 +1,71 @@
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Set
class ToolIntentionType(Enum):
    """
    High-level intentions that are common across tools
    """

    ADD_PACKAGE = auto()
    REMOVE_PACKAGE = auto()
    UPDATE_PACKAGE = auto()
    DOWNLOAD_PACKAGE = auto()
    SEARCH_PACKAGES = auto()
    SYNC_PACKAGES = auto()
    LIST_PACKAGES = auto()
    INIT_PROJECT = auto()
    BUILD_PROJECT = auto()
    RUN_SCRIPT = auto()
    # Fallback when the command cannot be mapped to a known intention.
    UNKNOWN = auto()
@dataclass
class Dependency:
    """
    Common representation of a dependency
    """

    # Package name as parsed from the argument.
    name: str
    # Position of the argument in the original argument list.
    arg_index: int
    # The raw argument text the dependency was parsed from.
    original_text: str
    # Version specifier string, if any (e.g. ">=1.0").
    version_constraint: Optional[str] = None
    # Requested extras for the package.
    extras: Set[str] = field(default_factory=set)
    is_dev_dependency: bool = False
    # Replacement argument text after typosquatting correction, if applied.
    corrected_text: Optional[str] = None
@dataclass
class CommandToolIntention:
    """
    Represents a parsed tool command with normalized intention
    """

    tool: str
    command: str
    intention_type: ToolIntentionType
    command_chain: List[str] = field(default_factory=list)
    packages: List[Dependency] = field(default_factory=list)
    options: Dict[str, Any] = field(default_factory=dict)
    raw_args: List[str] = field(default_factory=list)

    def modifies_packages(self) -> bool:
        """
        Check if this intention type modifies installed packages.
        """
        mutating = (
            ToolIntentionType.ADD_PACKAGE,
            ToolIntentionType.REMOVE_PACKAGE,
            ToolIntentionType.UPDATE_PACKAGE,
            ToolIntentionType.SYNC_PACKAGES,
        )
        return self.intention_type in mutating

    def queries_packages(self) -> bool:
        """
        Check if this intention type queries for packages.
        """
        querying = (
            ToolIntentionType.SEARCH_PACKAGES,
            ToolIntentionType.LIST_PACKAGES,
        )
        return self.intention_type in querying

View File

@@ -0,0 +1,4 @@
from .types import InterceptorType
from .factory import create_interceptor
# Public API of the interceptors package.
__all__ = ["InterceptorType", "create_interceptor"]

View File

@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List, Dict, Optional, Tuple
from .types import InterceptorType
from safety.meta import get_version
@dataclass
class Tool:
    """A package-manager tool whose command-line entry points can be intercepted."""

    name: str  # canonical tool name, e.g. "pip"
    binary_names: List[str]  # every executable name the tool may be invoked as
# TODO: Add Event driven output and support --safety-ping flag to test the
# interceptors status.
class CommandInterceptor(ABC):
    """
    Abstract base class for command interceptors.

    Provides the shared framework for installing and removing interceptors
    for the supported tools. Concrete subclasses implement the actual
    install/remove mechanics via `_batch_install_tools` and
    `_batch_remove_tools`.

    Attributes:
        interceptor_type (InterceptorType): The type of the interceptor.
        tools (Dict[str, Tool]): Mapping of tool name to its Tool definition.

    Note:
        All method implementations should be idempotent.
    """

    def __init__(self, interceptor_type: InterceptorType):
        self.interceptor_type = interceptor_type
        # pip is reachable under versioned names too (pip3.8 .. pip3.14).
        pip_binaries = ["pip", "pip3", *(f"pip3.{minor}" for minor in range(8, 15))]
        self.tools: Dict[str, Tool] = {
            "pip": Tool("pip", pip_binaries),
            "poetry": Tool("poetry", ["poetry"]),
            "uv": Tool("uv", ["uv"]),
            "npm": Tool("npm", ["npm"]),
        }

    @abstractmethod
    def _batch_install_tools(self, tools: List[Tool]) -> bool:
        """
        Install multiple tools at once. Must be implemented by subclasses.
        """

    @abstractmethod
    def _batch_remove_tools(self, tools: List[Tool]) -> bool:
        """
        Remove multiple tools at once. Must be implemented by subclasses.
        """

    def install_interceptors(self, tools: Optional[List[str]] = None) -> bool:
        """
        Install interceptors for the named tools, or for all tools when
        `tools` is None.
        """
        return self._batch_install_tools(self._get_tools(tools))

    def remove_interceptors(self, tools: Optional[List[str]] = None) -> bool:
        """
        Remove interceptors for the named tools, or for all tools when
        `tools` is None.
        """
        return self._batch_remove_tools(self._get_tools(tools))

    def _get_tools(self, tools: Optional[List[str]] = None) -> List[Tool]:
        """
        Resolve tool names to Tool objects; unknown names are silently dropped.
        """
        known = self.tools
        if tools is None:
            return list(known.values())
        return [known[name] for name in tools if name in known]

    def _generate_metadata_content(self, prepend: str) -> Tuple[str, str, str]:
        """
        Create metadata header lines for the files that are managed by us.
        `prepend` is the comment prefix of the target file format.
        """
        timestamp = datetime.now(timezone.utc).isoformat()
        return (
            f"{prepend} DO NOT EDIT THIS FILE DIRECTLY",
            f"{prepend} Last updated at: {timestamp}",
            f"{prepend} Updated by: safety v{get_version()}",
        )

View File

@@ -0,0 +1,31 @@
from sys import platform
from typing import Optional
from .types import InterceptorType
from .unix import UnixAliasInterceptor
from .windows import WindowsInterceptor
from .base import CommandInterceptor
def create_interceptor(
    interceptor_type: Optional[InterceptorType] = None,
) -> CommandInterceptor:
    """
    Build the interceptor implementation for the requested type, or pick one
    automatically from the running OS when no type is given.
    """
    implementations = {
        InterceptorType.UNIX_ALIAS: UnixAliasInterceptor,
        InterceptorType.WINDOWS_BAT: WindowsInterceptor,
    }

    # Explicit request wins over OS detection.
    if interceptor_type is not None:
        return implementations[interceptor_type]()

    if platform == "win32":
        return implementations[InterceptorType.WINDOWS_BAT]()

    if platform in ("linux", "linux2", "darwin"):
        # Default to alias-based on Unix-like systems
        return implementations[InterceptorType.UNIX_ALIAS]()

    raise NotImplementedError(f"Platform '{platform}' is not supported.")

View File

@@ -0,0 +1,6 @@
from enum import Enum, auto
class InterceptorType(Enum):
    """Available strategies for intercepting tool commands."""

    UNIX_ALIAS = auto()  # alias-based interception for Unix-like shells
    WINDOWS_BAT = auto()  # .bat-wrapper-based interception for Windows

View File

@@ -0,0 +1,204 @@
import logging
from pathlib import Path
import re
import shutil
import tempfile
from typing import List
from .base import CommandInterceptor, Tool
from .types import InterceptorType
from safety.constants import USER_CONFIG_DIR
logger = logging.getLogger(__name__)
class UnixAliasInterceptor(CommandInterceptor):
    """
    Intercepts tool invocations on Unix-like systems by writing
    `alias <binary>="safety <binary>"` definitions into a Safety-managed
    profile file and sourcing that profile from the user's shell rc files.
    """

    def __init__(self):
        super().__init__(InterceptorType.UNIX_ALIAS)
        self.user_rc_paths: List[Path] = self._get_user_rc_paths()
        self.custom_rc_path = self._get_custom_rc_path()
        # NOTE(review): only consulted during removal — presumably rc files
        # written by earlier releases; confirm before removing.
        self.legacy_user_rc_paths = [Path.home() / ".profile"]
        # Update these markers could be a breaking change; be careful to handle
        # backward compatibility
        self.marker_start = "# >>> Safety >>>"
        self.marker_end = "# <<< Safety <<<"

    def _get_user_rc_paths(self) -> List[Path]:
        """
        We support the following shells:
        * Zsh
        * Bash
        """
        zsh_paths = [Path.home() / ".zshrc"]
        # .bash_profile is added for max compatibility on macOS
        bash_profile = Path.home() / ".bash_profile"
        bashrc = Path.home() / ".bashrc"
        profile = Path.home() / ".profile"
        # Default is .bash_profile even when none of the candidates exist yet.
        bash_paths = [bash_profile]
        if bash_profile.exists():
            bash_paths = [bash_profile]
        elif bashrc.exists():
            bash_paths = [bashrc]
        elif profile.exists():
            bash_paths = [profile]
        return zsh_paths + bash_paths

    def _get_custom_rc_path(self) -> Path:
        # The Safety-owned profile file that holds the generated aliases.
        return USER_CONFIG_DIR / ".safety_profile"

    def _backup_file(self, path: Path) -> None:
        """
        Create backup of file if it exists
        """
        # e.g. ~/.zshrc -> ~/.zshrc.backup; an existing backup is overwritten.
        if path.exists():
            backup_path = path.with_suffix(".backup")
            shutil.copy2(path, backup_path)

    def _generate_user_rc_content(self) -> str:
        """
        Generate the content to be added to user's rc.
        Example:
        ```
        # >>> Safety >>>
        [ -f "$HOME/.safety/.safety_profile" ] && . "$HOME/.safety/.safety_profile"
        # <<< Safety <<<
        ```
        """
        lines = (
            self.marker_start,
            f'[ -f "{self.custom_rc_path}" ] && . "{self.custom_rc_path}"',
            self.marker_end,
        )
        return "\n".join(lines) + "\n"

    def _is_configured(self, user_rc_path: Path) -> bool:
        """
        Check if the configuration block exists in user's rc file
        """
        try:
            if not user_rc_path.exists():
                return False
            content = user_rc_path.read_text()
            return self.marker_start in content and self.marker_end in content
        except OSError:
            logger.info("Failed to read user's rc file")
            return False

    def _generate_custom_rc_content(self, aliases: List[str]) -> str:
        """
        Generate the content for the custom profile with metadata
        """
        metadata_lines = self._generate_metadata_content(prepend="#")
        aliases_lines = tuple(aliases)
        lines = (
            (self.marker_start,) + metadata_lines + aliases_lines + (self.marker_end,)
        )
        return "\n".join(lines) + "\n"

    def _ensure_source_line_in_user_rc(self) -> None:
        """
        Ensure source line exists in user's rc files
        If the source line is not present in the user's rc files, append it.
        If the user's rc files do not exist, create them.
        """
        source_line = self._generate_user_rc_content()
        for user_rc_path in self.user_rc_paths:
            if not user_rc_path.exists():
                user_rc_path.write_text(source_line)
                continue
            if not self._is_configured(user_rc_path):
                with open(user_rc_path, "a") as f:
                    f.write(source_line)

    def _batch_install_tools(self, tools: List[Tool]) -> bool:
        """
        Install aliases for multiple tools
        """
        try:
            # Generate aliases
            aliases = []
            for tool in tools:
                for binary in tool.binary_names:
                    alias_def = f'alias {binary}="safety {binary}"'
                    aliases.append(alias_def)
            if not aliases:
                return False
            # Create safety profile directory if it doesn't exist
            self.custom_rc_path.parent.mkdir(parents=True, exist_ok=True)
            # Generate new profile content
            content = self._generate_custom_rc_content(aliases)
            # Backup target files
            for f_path in self.user_rc_paths + [self.custom_rc_path]:
                self._backup_file(path=f_path)
            # Override our custom profile
            # TODO: handle exceptions
            self.custom_rc_path.write_text(content)
            # Ensure source line in user's rc files
            self._ensure_source_line_in_user_rc()
            return True
        except Exception:
            logger.exception("Failed to batch install aliases")
            return False

    def _batch_remove_tools(self, tools: List[Tool]) -> bool:
        """
        This will remove all the tools.
        NOTE: for now this does not support to remove individual tools.
        """
        try:
            # Backup target files
            for f_path in self.user_rc_paths + [self.custom_rc_path]:
                self._backup_file(path=f_path)
            for user_rc_path in self.user_rc_paths + self.legacy_user_rc_paths:
                if self._is_configured(user_rc_path):
                    temp_dir = tempfile.gettempdir()
                    temp_file = Path(temp_dir) / f"{user_rc_path.name}.tmp"
                    # NOTE(review): the markers are not re.escape()d, and the
                    # stray "\" before the closing marker escapes its leading
                    # "#" — harmless for the current marker text, but this
                    # would break if the markers ever gain regex metachars.
                    pattern = rf"{self.marker_start}\n.*?\{self.marker_end}\n?"
                    with open(user_rc_path, "r") as src, open(temp_file, "w") as dst:
                        content = src.read()
                        cleaned_content = re.sub(pattern, "", content, flags=re.DOTALL)
                        dst.write(cleaned_content)
                    if not temp_file.exists():
                        logger.info("Temp file is empty or invalid")
                        return False
                    # Atomically replace the rc file with the cleaned copy.
                    shutil.move(str(temp_file), str(user_rc_path))
            self.custom_rc_path.unlink(missing_ok=True)
            return True
        except Exception as e:
            logger.exception(f"Failed to batch remove aliases: {e}")
            return False

    def _install_tool(self, tool: Tool) -> bool:
        # Single-tool convenience wrapper over the batch implementation.
        return self._batch_install_tools([tool])

    def _remove_tool(self, tool: Tool) -> bool:
        # Single-tool convenience wrapper over the batch implementation.
        return self._batch_remove_tools([tool])

View File

@@ -0,0 +1,574 @@
import logging
import os
import re
import shutil
from pathlib import Path
from sys import platform
from typing import TYPE_CHECKING, Dict, List
from .base import CommandInterceptor, Tool
from .types import InterceptorType
if TYPE_CHECKING or platform == "win32":
import winreg
from typing import Union
logger = logging.getLogger(__name__)
class AutoRunManager:
    """
    Manages Windows Command Processor AutoRun registry entries.

    The AutoRun value under HKCU is a single command line cmd.exe executes on
    startup; multiple commands are chained with separators such as ``&``.
    This class adds/removes Safety's wrapper script while preserving any
    commands the user already had there.
    """

    # NOTE: raw string — backslashes must be single here. The previous
    # r"Software\\Microsoft\\Command Processor" contained literal doubled
    # backslashes and therefore addressed a non-existent key path.
    REGISTRY_KEY = r"Software\Microsoft\Command Processor"
    REGISTRY_VALUE = "AutoRun"

    def add_script(self, script_path: "Union[str, Path]") -> bool:
        """
        Add script to AutoRun, preserving existing commands.

        Returns:
            True on success (including when the script was already present),
            False when the registry could not be updated.
        """
        script_path = str(script_path)
        try:
            with self._open_registry_key() as key:
                tokens = self._get_current_tokens(key)
                if not self._script_exists_in_tokens(tokens, script_path):
                    # Prepend our script; keep whatever the user already had.
                    new_tokens = (
                        [script_path, " & "] + tokens if tokens else [script_path]
                    )
                    self._set_autorun_value(key, "".join(new_tokens))
            return True
        except Exception:
            logger.info("Failed to add script to AutoRun")
            return False

    # Annotation is quoted (like add_script's): on non-Windows platforms
    # ``Union`` is imported only under the TYPE_CHECKING/win32 guard, so an
    # eagerly-evaluated annotation would raise NameError at import time.
    def remove_script(self, script_path: "Union[str, Path]") -> bool:
        """
        Remove script from AutoRun, preserving other commands.

        Returns:
            True when the script is no longer registered, False on error.
        """
        script_path = str(script_path)
        try:
            with self._open_registry_key() as key:
                tokens = self._get_current_tokens(key)
                if self._script_exists_in_tokens(tokens, script_path):
                    cleaned_tokens = self._remove_script_tokens(tokens, script_path)
                    if cleaned_tokens:
                        self._set_autorun_value(key, " ".join(cleaned_tokens))
                    else:
                        # Nothing left: drop the AutoRun value entirely.
                        self._delete_autorun_value(key)
            return True
        except Exception:
            logger.info("Failed to remove script from AutoRun")
            return False

    def get_current_commands(self) -> List[str]:
        """
        Get list of current AutoRun commands (separator tokens filtered out).
        """
        try:
            with self._open_registry_key() as key:
                tokens = self._get_current_tokens(key)
                return [
                    token.strip()
                    for token in tokens
                    if not self._is_separator(token) and token.strip()
                ]
        except Exception:
            logger.info("Failed to get current AutoRun value")
            return []

    def _open_registry_key(self):
        """
        Open the Command Processor key with read/write access, creating it
        when missing. The returned handle supports the context-manager
        protocol.
        """
        try:
            return winreg.OpenKey(
                winreg.HKEY_CURRENT_USER,
                self.REGISTRY_KEY,
                0,
                winreg.KEY_READ | winreg.KEY_SET_VALUE,
            )
        except FileNotFoundError:
            logger.info("Failed to open registry key")
            logger.info("Creating registry key: %s", self.REGISTRY_KEY)
            return winreg.CreateKey(winreg.HKEY_CURRENT_USER, self.REGISTRY_KEY)

    def _get_current_tokens(self, key) -> List[str]:
        """
        Get current AutoRun value as tokens; empty list when the value is
        not set.
        """
        try:
            existing_value, _ = winreg.QueryValueEx(key, self.REGISTRY_VALUE)
            return self._tokenize_autorun(existing_value)
        except FileNotFoundError:
            logger.info("Failed to get current AutoRun value")
            return []

    def _is_our_script(self, token: str, script_path: str) -> bool:
        """
        Check if token is our script (ignoring surrounding whitespace).
        """
        return token.strip() == script_path

    def _is_separator(self, token: str) -> bool:
        """
        Check if token is a command separator that can be used to chain
        commands in the AutoRun value.
        """
        return token.strip() in ["&", "&&", "|", "||"]

    def _tokenize_autorun(self, autorun_value: str) -> List[str]:
        """
        Tokenize AutoRun value preserving commands, separators, and spacing.
        Simple character-by-character parsing approach.
        """
        if not autorun_value:
            return []
        tokens = []
        current_token = ""
        i = 0
        while i < len(autorun_value):
            char = autorun_value[i]
            if char in "&|":
                # Save current token if exists
                if current_token:
                    tokens.append(current_token)
                    current_token = ""
                # Handle double operators (&&, ||)
                if i + 1 < len(autorun_value) and autorun_value[i + 1] == char:
                    tokens.append(char + char)  # && or ||
                    i += 2
                else:
                    tokens.append(char)  # & or |
                    i += 1
            else:
                current_token += char
                i += 1
        if current_token:
            tokens.append(current_token)
        return tokens

    def _script_exists_in_tokens(self, tokens: List[str], script_path: str) -> bool:
        """
        Check if script already exists in token list.
        """
        return any(token.strip() == script_path for token in tokens)

    def _remove_script_tokens(self, tokens: List[str], script_path: str) -> List[str]:
        """
        Remove our script and clean up separators.
        """
        result = []
        i = 0
        while i < len(tokens):
            if self._is_our_script(tokens[i], script_path):
                # Skip our script; also swallow the separator that followed it
                # so we don't leave a dangling "&" behind.
                if i + 1 < len(tokens) and self._is_separator(tokens[i + 1]):
                    i += 2  # Skip script + separator
                else:
                    i += 1  # Skip just the script
            else:
                result.append(tokens[i].strip())
                i += 1
        return result

    def _set_autorun_value(self, key, value: str):
        """
        Set AutoRun registry value.
        """
        winreg.SetValueEx(key, self.REGISTRY_VALUE, 0, winreg.REG_SZ, value)

    def _delete_autorun_value(self, key):
        """
        Delete AutoRun registry value.
        """
        winreg.DeleteValue(key, self.REGISTRY_VALUE)
class WindowsInterceptor(CommandInterceptor):
    """
    Intercepts tool invocations on Windows by writing per-binary ``.bat``
    wrappers into a Safety-owned scripts directory, prepending that directory
    to the user's PATH, and installing cmd.exe (AutoRun doskey macros) and
    PowerShell (profile-sourced functions) wrappers for virtualenv support.
    """

    def __init__(self):
        super().__init__(InterceptorType.WINDOWS_BAT)
        self.scripts_dir = Path.home() / "AppData" / "Local" / "safety"
        # Ensure the scripts directory exists
        # This makes sure that if a user is using a sandboxed Python
        # installation from the Microsoft Store, the directory is created and
        # the .resolve() method works correctly.
        self.scripts_dir.mkdir(parents=True, exist_ok=True)
        self.scripts_dir = self.scripts_dir.resolve()
        self.backup_dir = self.scripts_dir / "backups"
        self.backup_win_env_path = self.backup_dir / "path_backup.txt"
        self.venv_pwshell_wrapper_path = self.scripts_dir / "venv-wrappers.ps1"
        self.venv_cmd_wrapper_path = self.scripts_dir / "venv-wrappers.bat"
        # Update these markers could be a breaking change; be careful to handle
        # backward compatibility
        self.marker_start = ">>> Safety >>>"
        self.marker_end = "<<< Safety <<<"

    def _backup_path_env(self, path_content: str) -> None:
        """
        Backup current PATH to a file
        """
        self.backup_dir.mkdir(parents=True, exist_ok=True)
        metadata_lines = self._generate_metadata_content(prepend="")
        lines = (
            (self.marker_start,) + metadata_lines + (path_content,) + (self.marker_end,)
        )
        content = "\n".join(lines) + "\n"
        self.backup_win_env_path.write_text(content)

    def _generate_bat_content(self, tool_name: str) -> str:
        """
        Generate the content for the bat with metadata
        """
        metadata_lines = self._generate_metadata_content(prepend="REM")
        no_echo = "@echo off"
        # The wrapper forwards all arguments (%*) to `safety <tool>`.
        wrapper = f"safety {tool_name} %*"
        lines = (
            (
                no_echo,
                f"REM {self.marker_start}",
            )
            + metadata_lines
            + (wrapper,)
            + (f"REM {self.marker_end}",)
        )
        return "\n".join(lines) + "\n"

    def __generate_cmd_wrapper_content(self, binaries: List[str]) -> str:
        """
        Generate the content for the cmd wrapper with metadata
        """
        metadata_lines = self._generate_metadata_content(prepend="REM")
        no_echo = "@echo off"
        wrappers = []
        for binary in binaries:
            # doskey macro so the wrapper also wins inside an activated venv,
            # where the venv's Scripts dir would otherwise shadow our PATH entry.
            wrapper = f"doskey {binary}={self.scripts_dir / binary}.bat $*"
            wrappers.append(wrapper)
        comment_top = f"{no_echo}\nREM {self.marker_start}"
        comment_bottom = f"REM {self.marker_end}"
        lines = (
            comment_top,
            *metadata_lines,
            *wrappers,
            comment_bottom,
        )
        return "\n".join(lines) + "\n"

    def __generate_powershell_wrapper_content(self, binaries: List[str]) -> str:
        """
        Generate the content for the powershell wrapper with PowerShell functions
        """
        metadata_lines = self._generate_metadata_content(prepend="#")
        wrappers = []
        for binary in binaries:
            bat_path = self.scripts_dir / f"{binary}.bat"
            wrapper = f"""function {binary} {{
    param([Parameter(ValueFromRemainingArguments)]$args)
    & "{bat_path}" @args
}}"""
            wrappers.append(wrapper)
        comment_top = f"# {self.marker_start}"
        comment_bottom = f"# {self.marker_end}"
        lines = [
            comment_top,
            *metadata_lines,
            *wrappers,
            comment_bottom,
        ]
        return "\n".join(lines) + "\n"

    def __generate_powershell_sourcing_content(self, script_path: Path) -> str:
        """
        Generate the PowerShell sourcing content with Safety markers
        """
        lines = [
            f"# {self.marker_start}",
            f". {script_path}",
            f"# {self.marker_end}",
        ]
        return "\n".join(lines) + "\n"

    def __get_powershell_profiles(self) -> Dict[str, Path]:
        """
        Get the CurrentUserAllHosts profile paths for available PowerShell versions
        Returns a dictionary with shell executable as key and profile path as value
        """
        profiles = {}
        shells = [("powershell.exe", "PowerShell 5.1"), ("pwsh.exe", "PowerShell 7+")]
        for shell, shell_name in shells:
            try:
                import subprocess

                # Check if the shell is available
                try:
                    subprocess.run(
                        [shell, "-Command", "exit"],
                        capture_output=True,
                        text=True,
                        check=False,
                    )
                except FileNotFoundError:
                    logger.info(f"{shell_name} not found, skipping profile setup")
                    continue
                # Get the CurrentUserAllHosts profile path
                cmd = [
                    shell,
                    "-Command",
                    "Get-Variable PROFILE -ValueOnly | Select-Object -ExpandProperty CurrentUserAllHosts",
                ]
                result = subprocess.run(
                    cmd, capture_output=True, text=True, check=False
                )
                result_stdout = result.stdout.strip()
                if result.returncode == 0 and result_stdout:
                    profile_path = Path(result_stdout)
                    # Ensure parent directory exists
                    profile_path.parent.mkdir(parents=True, exist_ok=True)
                    # Create the file if it doesn't exist
                    if not profile_path.exists():
                        profile_path.touch()
                    profiles[shell] = profile_path
                    logger.info(f"Found {shell_name} profile at {profile_path}")
                else:
                    logger.info(
                        f"Failed to get {shell_name} profile path: {result.stderr.strip()}"
                    )
            except Exception as e:
                logger.info(f"Error while getting {shell_name} profile: {str(e)}")
        # Fallback to default profile path if no profiles were found
        if not profiles:
            default_path = (
                Path.home() / "Documents" / "WindowsPowerShell" / "profile.ps1"
            )
            default_path.parent.mkdir(parents=True, exist_ok=True)
            if not default_path.exists():
                default_path.touch()
            profiles["fallback"] = default_path
            logger.info(f"Using fallback profile at {default_path}")
        return profiles

    def _install_venv_wrappers(self, binaries: List[str]):
        """
        Install specific wrappers for virtualenv support on Windows
        """
        # Refresh scripts content
        # CMD wrappers
        cmd_wrapper = self.__generate_cmd_wrapper_content(binaries)
        self.venv_cmd_wrapper_path.write_text(cmd_wrapper)
        # PowerShell wrappers
        powershell_wrapper = self.__generate_powershell_wrapper_content(binaries)
        self.venv_pwshell_wrapper_path.write_text(powershell_wrapper)
        # Link CMD wrapper to Autorun
        autorun_manager = AutoRunManager()
        autorun_manager.add_script(self.venv_cmd_wrapper_path)
        # Link Powershell wrapper to Powershell PROFILEs
        profiles = self.__get_powershell_profiles()
        pwshell_source = self.__generate_powershell_sourcing_content(
            self.venv_pwshell_wrapper_path
        )
        for _, profile_path in profiles.items():
            try:
                # Read current content or create empty string if file doesn't exist yet
                try:
                    profile_content = profile_path.read_text()
                except FileNotFoundError:
                    profile_path.parent.mkdir(parents=True, exist_ok=True)
                    profile_content = ""
                # Add sourcing command if not already present
                if self.marker_start not in profile_content:
                    if profile_content and not profile_content.endswith("\n"):
                        profile_content += "\n"
                    profile_content += pwshell_source
                    profile_path.write_text(profile_content)
                    logger.info(f"Added PowerShell wrapper to {profile_path}")
            except Exception as e:
                logger.info(
                    f"Failed to update PowerShell profile at {profile_path}: {str(e)}"
                )

    def _remove_venv_wrappers(self):
        """
        Remove specific wrappers for virtualenv support on Windows.
        This is an idempotent operation.
        """
        # For CMD
        autorun_manager = AutoRunManager()
        autorun_manager.remove_script(self.venv_cmd_wrapper_path)
        # For PowerShell
        # Remove Powershell wrapper from all PowerShell profiles
        profiles = self.__get_powershell_profiles()
        for _, profile_path in profiles.items():
            try:
                if profile_path.exists():
                    profile_content = profile_path.read_text()
                    if self.marker_start not in profile_content:
                        logger.info(f"PowerShell wrapper not found in {profile_path}")
                        continue
                    # Look for our sourcing line and the comment block we added
                    # Remove the entire block including comments
                    lines = profile_content.splitlines()
                    new_lines = []
                    skip_block = False
                    for line in lines:
                        if self.marker_start in line:
                            skip_block = True
                            continue
                        if skip_block:
                            if self.marker_end in line:
                                skip_block = False
                            continue
                        new_lines.append(line)
                    new_content = "\n".join(new_lines)
                    # Collapse runs of blank lines left behind by the removal.
                    new_content = re.sub(r"\n{3,}", "\n\n", new_content)
                    profile_path.write_text(new_content)
                    logger.info(f"Removed PowerShell wrapper from {profile_path}")
            except Exception as e:
                logger.info(
                    f"Failed to remove PowerShell wrapper from {profile_path}: {str(e)}"
                )

    def _batch_install_tools(self, tools: List[Tool]) -> bool:
        """
        Install interceptors for multiple tools at once
        """
        try:
            wrappers = []
            for tool in tools:
                for binary in tool.binary_names:
                    # TODO: Switch to binary once we support safety pip3, etc.
                    wrapper = self._generate_bat_content(tool.name)
                    wrappers.append((binary, wrapper))
            if not wrappers:
                return False
            # Create safety directory if it doesn't exist
            self.scripts_dir.mkdir(parents=True, exist_ok=True)
            for binary, wrapper in wrappers:
                wrapper_path = self.scripts_dir / f"{binary}.bat"
                wrapper_path.write_text(wrapper)
            # Virtualenv environment wrappers
            all_binaries = [binary for tool in tools for binary in tool.binary_names]
            self._install_venv_wrappers(binaries=all_binaries)
            # Add scripts directory to PATH if needed
            self._update_path()
            return True
        except Exception as e:
            logger.info("Failed to batch install tools: %s", e)
            return False

    def _batch_remove_tools(self, tools: List[Tool]) -> bool:
        """
        Remove interceptors for multiple tools at once.
        Note: We don't support removing specific tools yet,
        so we remove all tools.
        """
        try:
            self._update_path(remove=True)
            # NOTE(review): this also deletes backup_dir, which lives under
            # scripts_dir — confirm the PATH backup is meant to go with it.
            if self.scripts_dir.exists():
                shutil.rmtree(self.scripts_dir)
            self._remove_venv_wrappers()
            return True
        except Exception as e:
            logger.info("Failed to batch remove tools: %s", e)
            return False

    def _update_path(self, remove: bool = False) -> bool:
        """
        Update Windows PATH environment variable
        """
        try:
            with winreg.OpenKey(
                winreg.HKEY_CURRENT_USER, "Environment", 0, winreg.KEY_ALL_ACCESS
            ) as key:
                # Get current PATH value
                try:
                    path_val = winreg.QueryValueEx(key, "PATH")[0]
                    self._backup_path_env(path_content=path_val)
                except FileNotFoundError:
                    path_val = ""
                # Convert to Path objects
                paths = [Path(p) for p in path_val.split(os.pathsep) if p]
                if remove:
                    if self.scripts_dir in paths:
                        paths.remove(self.scripts_dir)
                        new_path = os.pathsep.join(str(p) for p in paths)
                        winreg.SetValueEx(
                            key, "PATH", 0, winreg.REG_EXPAND_SZ, new_path
                        )
                else:
                    if self.scripts_dir not in paths:
                        paths.insert(0, self.scripts_dir)  # Add to beginning
                        new_path_val = os.pathsep.join(str(p) for p in paths)
                        winreg.SetValueEx(
                            key, "PATH", 0, winreg.REG_EXPAND_SZ, new_path_val
                        )
            return True
        except Exception as e:
            logger.info("Failed to update PATH: %s", e)
            return False

View File

@@ -0,0 +1,115 @@
from typing import Any, Dict, List, Tuple
import os.path
from pathlib import Path
from typing import Optional
from safety.constants import USER_CONFIG_DIR
from safety.tool.utils import (
NpmConfigurator,
NpmProjectConfigurator,
PipConfigurator,
PipRequirementsConfigurator,
PoetryConfigurator,
PoetryPyprojectConfigurator,
UvConfigurator,
UvPyprojectConfigurator,
is_os_supported,
)
from safety_schemas.models.events.types import ToolType
from .interceptors import create_interceptor
import logging
logger = logging.getLogger(__name__)
def find_local_tool_files(directory: Path) -> List[Path]:
    """
    Find files directly inside *directory* that a tool configurator supports
    (e.g. requirements files, pyproject.toml, package.json).

    Args:
        directory: Directory to scan (non-recursive).

    Returns:
        Paths (joined with *directory*) of the supported files. A file
        supported by more than one configurator appears once per match,
        preserving the previous behavior.
    """
    configurators = [
        PipRequirementsConfigurator(),
        PoetryPyprojectConfigurator(),
        UvPyprojectConfigurator(),
        NpmProjectConfigurator(),
    ]
    results = []
    for file_name in os.listdir(directory):
        # Join with the directory: os.listdir yields bare names, so checking
        # them directly would resolve against the current working directory
        # instead of *directory*.
        file = Path(directory) / file_name
        if file.is_file():
            for configurator in configurators:
                if configurator.is_supported(file):
                    results.append(file)
    return results
def configure_system(org_slug: Optional[str]) -> List[Tuple[ToolType, Optional[Path]]]:
    """
    Configure every supported package manager system-wide.

    Args:
        org_slug: Organization slug forwarded to each configurator.

    Returns:
        One ``(tool_type, result)`` pair per configurator, where ``result``
        is whatever that configurator's ``configure`` call returned.
    """
    configurators: List[Tuple[ToolType, Any, Dict[str, Any]]] = [
        (ToolType.PIP, PipConfigurator(), {"org_slug": org_slug}),
        (ToolType.POETRY, PoetryConfigurator(), {"org_slug": org_slug}),
        (ToolType.UV, UvConfigurator(), {"org_slug": org_slug}),
        (ToolType.NPM, NpmConfigurator(), {"org_slug": org_slug}),
    ]
    return [
        (tool_type, configurator.configure(**kwargs))
        for tool_type, configurator, kwargs in configurators
    ]
def reset_system():
    """Undo system-wide configuration for every supported package manager."""
    # Instantiate everything up front, then reset each in turn.
    pending = [
        PipConfigurator(),
        PoetryConfigurator(),
        UvConfigurator(),
        NpmConfigurator(),
    ]
    for tool_configurator in pending:
        tool_configurator.reset()
def configure_alias() -> Optional[List[Tuple[ToolType, Optional[Path]]]]:
    """
    Install shell/OS-level interceptors (aliases or wrappers) for all tools.

    Returns:
        None when the OS is unsupported. Otherwise one ``(tool_type, path)``
        pair per tool, where ``path`` is the configured Safety profile on
        success, or None when installation failed.
    """
    if not is_os_supported():
        logger.warning("OS not supported for alias configuration.")
        return None

    interceptor = create_interceptor()
    installed = interceptor.install_interceptors()

    # Join with pathlib (consistent with the interceptor implementations)
    # instead of string-formatting the path.
    config = USER_CONFIG_DIR / ".safety_profile" if installed else None
    tool_types = (ToolType.PIP, ToolType.POETRY, ToolType.UV, ToolType.NPM)
    return [(tool_type, config) for tool_type in tool_types]
def configure_local_directory(
    directory: Path, org_slug: Optional[str], project_id: Optional[str]
):
    """
    Configure every supported project file found directly inside *directory*.

    Args:
        directory: Directory to scan (non-recursive).
        org_slug: Organization slug passed through to each configurator.
        project_id: Project id passed through to each configurator.
    """
    configurators = [
        PipRequirementsConfigurator(),
        PoetryPyprojectConfigurator(),
        UvPyprojectConfigurator(),
        NpmProjectConfigurator(),
    ]
    for file_name in os.listdir(directory):
        # Join with the directory: os.listdir yields bare names, so checking
        # them directly would resolve against the current working directory
        # instead of *directory*.
        file = Path(directory) / file_name
        if file.is_file():
            for configurator in configurators:
                if configurator.is_supported(file):
                    configurator.configure(file, org_slug, project_id)

View File

@@ -0,0 +1,191 @@
from typing import Any, List, Protocol, Tuple, Dict, Literal, runtime_checkable
import typer
from rich.padding import Padding
from .base import EnvironmentDiffTracker
from safety.console import main_console as console
from safety.init.render import render_header, progressive_print
from safety.models import ToolResult
import logging
from .intents import ToolIntentionType, CommandToolIntention
from .environment_diff import PackageLocation
logger = logging.getLogger(__name__)
@runtime_checkable
class AuditableCommand(Protocol):
    """
    Protocol defining the contract for classes that can be audited for packages.

    NOTE: runtime_checkable isinstance() checks only verify that the members
    exist, not their signatures.
    """

    @property
    def _diff_tracker(self) -> "EnvironmentDiffTracker":
        """
        Provides package tracking functionality.
        """
        ...

    def get_ecosystem(self) -> Literal["pypi", "npmjs"]:
        """
        Return the ecosystem used by the command implementation.
        """
        ...
class InstallationAuditMixin:
    """
    Mixin providing installation audit functionality for command classes.
    This mixin can be used by any command class that needs to audit
    installation and show warnings.
    Classes using this mixin should conform to the AuditableCommand protocol.
    """

    def render_installation_warnings(
        self, ctx: typer.Context, packages_audit: Dict[str, Any]
    ):
        """
        Render installation warnings based on package audit results.
        Packages with zero reported vulnerabilities are skipped.

        Args:
            ctx: The typer context
            packages_audit: pre-fetched audit data
        """
        warning_messages = []
        for audited_package in packages_audit.get("audit", {}).get("packages", []):
            vulnerabilities = audited_package.get("vulnerabilities", {})
            critical_vulnerabilities = vulnerabilities.get("critical", 0)
            # Total across all severities (the mapping's values are counts).
            total_vulnerabilities = 0
            for count in vulnerabilities.values():
                total_vulnerabilities += count
            if total_vulnerabilities == 0:
                continue
            warning_message = f"[[yellow]Warning[/yellow]] {audited_package.get('package_specifier')} contains {total_vulnerabilities} {'vulnerabilities' if total_vulnerabilities != 1 else 'vulnerability'}"
            if critical_vulnerabilities > 0:
                warning_message += f", including {critical_vulnerabilities} critical severity {'vulnerabilities' if critical_vulnerabilities != 1 else 'vulnerability'}"
            warning_message += "."
            warning_messages.append(warning_message)
        if len(warning_messages) > 0:
            console.print()
            render_header(" Safety Report")
            progressive_print(warning_messages)
            console.line()

    def render_package_details(self: "AuditableCommand", packages: List[str]):
        """
        Render "learn more" links for the given package names — one combined
        line for npmjs, one line per package for pypi.
        """
        if "npmjs" == self.get_ecosystem():
            url = "https://getsafety.com/"
            failed = ", ".join(packages)
            console.line()
            console.print(
                Padding(
                    f"Learn more about {failed} in [link]{url}[/link]",
                    (0, 0, 0, 1),
                ),
                emoji=True,
            )
        else:
            for package_name in packages:
                console.print(
                    Padding(
                        f"Learn more: [link]https://data.safetycli.com/packages/pypi/{package_name}/[/link]",
                        (0, 0, 0, 1),
                    ),
                    emoji=True,
                )

    def audit_packages(
        self: "AuditableCommand", ctx: typer.Context
    ) -> Tuple[Dict[str, Any], Dict[PackageLocation, str]]:
        """
        Audit packages based on environment diff tracking.
        Override this method in your command class if needed.

        Args:
            ctx: The typer context

        Returns:
            A ``(audit_result, packages)`` tuple; both are empty dicts when
            there is nothing to audit or the audit API call fails.
        """
        try:
            diff_tracker = getattr(self, "_diff_tracker", None)
            if diff_tracker and hasattr(diff_tracker, "get_diff"):
                # Audit both newly added and version-updated packages.
                added, _, updated = diff_tracker.get_diff()
                packages: Dict[PackageLocation, str] = {**added, **updated}
                if (
                    hasattr(ctx.obj, "auth")
                    and hasattr(ctx.obj.auth, "client")
                    and packages
                ):
                    # npm uses name@version, pypi uses name==version specifiers.
                    ecosystem = self.get_ecosystem()
                    eq_exp = "@" if ecosystem == "npmjs" else "=="
                    # For updates, the tracked value may be an (old, new)
                    # tuple — audit the new (last) version.
                    return (
                        ctx.obj.auth.client.audit_packages(
                            [
                                f"{package.name}{eq_exp}{version[-1] if isinstance(version, tuple) else version}"
                                for (package, version) in packages.items()
                            ],
                            ecosystem,
                        ),
                        packages,
                    )
        except Exception:
            logger.debug("Audit API failed with error", exc_info=True)
        # Always return a dict to satisfy the return type
        return dict(), dict()

    def handle_installation_audit(self, ctx: typer.Context, result: ToolResult):
        """
        Handle installation audit and rendering warnings/details.
        This is an explicit method that can be called from a command's after method.
        Usage example:
            def after(self, ctx, result):
                super().after(ctx, result)
                self.handle_installation_audit(ctx, result)

        Args:
            ctx: The typer context
            result: The tool result

        Raises:
            TypeError: when the host class does not satisfy AuditableCommand.
        """
        if not isinstance(self, AuditableCommand):
            raise TypeError(
                "handle_installation_audit can only be called on instances of AuditableCommand"
            )
        audit_result, packages = self.audit_packages(ctx)
        self.render_installation_warnings(ctx, audit_result)
        # NOTE(review): details are rendered when the process is missing or
        # exited non-zero — presumably for blocked/failed installs; confirm.
        if not result.process or result.process.returncode != 0:
            package_names = {pl.name for pl in packages}
            # Access _intention safely to keep the protocol minimal and satisfy type checkers
            intent = getattr(self, "_intention", None)
            if isinstance(intent, CommandToolIntention):
                command_intent: CommandToolIntention = intent
                if (
                    command_intent.intention_type
                    is not ToolIntentionType.REMOVE_PACKAGE
                    and command_intent.packages
                ):
                    # Include the packages the user *asked* for, even if the
                    # failed process never installed them.
                    for dep in command_intent.packages:
                        if dep.name:
                            package_names.add(dep.name)
            if package_names:
                self.render_package_details(sorted(package_names))

View File

@@ -0,0 +1,5 @@
from .main import Npm
__all__ = [
"Npm",
]

View File

@@ -0,0 +1,173 @@
from typing import TYPE_CHECKING, List, Optional, Dict, Any, Literal
from typing import Tuple
import logging
import typer
from safety.models import ToolResult
from .parser import NpmParser
from ..base import BaseCommand
from safety_schemas.models.events.types import ToolType
from ..environment_diff import EnvironmentDiffTracker, NpmEnvironmentDiffTracker
from ..mixins import InstallationAuditMixin
from ..constants import TOP_NPMJS_PACKAGES
from ..auth import build_index_url
import json
if TYPE_CHECKING:
from ..environment_diff import EnvironmentDiffTracker
logger = logging.getLogger(__name__)
class NpmCommand(BaseCommand):
    """
    Main class for hooks into npm commands.
    """

    def get_tool_type(self) -> ToolType:
        return ToolType.NPM

    def get_command_name(self) -> List[str]:
        return ["npm"]

    def get_ecosystem(self) -> Literal["pypi", "npmjs"]:
        return "npmjs"

    def get_package_list_command(self) -> List[str]:
        # --json --all -l: the full dependency tree as JSON with per-package
        # details (the "path" field is consumed by _flatten_packages).
        return [*self.get_command_name(), "list", "--json", "--all", "-l"]

    def _flatten_packages(self, dependencies: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Flatten npm list --json --all -l output into a list of package dictionaries with file paths.

        Args:
            dependencies: The root dependencies dictionary from JSON output from npm list --json --all -l

        Returns:
            List of dictionaries with name, version, and location keys
        """
        result = []

        def traverse(dependencies: Dict[str, Any]):
            # Depth-first walk; a package can appear multiple times when it
            # occurs at several places in the tree.
            if not dependencies:
                return
            for name, info in dependencies.items():
                result.append(
                    {
                        "name": name,
                        "version": info.get("version", ""),
                        "location": info.get("path", ""),
                    }
                )
                # Recursively process nested dependencies
                if "dependencies" in info:
                    traverse(info["dependencies"])

        traverse(dependencies)
        return result

    def parse_package_list_output(self, output: str) -> List[Dict[str, Any]]:
        """
        Handle the output of the npm list command.

        Args:
            output: Command output

        Returns:
            List[Dict[str, Any]]: List of package dictionaries
        """
        try:
            result = json.loads(output)
        except json.JSONDecodeError:
            # Log error and return empty list
            logger.exception(f"Error parsing package list output: {output[:100]}...")
            return []
        return self._flatten_packages(result.get("dependencies", {}))

    def get_diff_tracker(self) -> "EnvironmentDiffTracker":
        # npm-specific diff tracker implementation.
        return NpmEnvironmentDiffTracker()

    def _get_typosquatting_reference_packages(self) -> Tuple[str, ...]:
        # Reference set of popular package names for typosquatting checks.
        return TOP_NPMJS_PACKAGES

    @classmethod
    def from_args(cls, args: List[str], **kwargs):
        """
        Build the most specific command object for the given npm argv:
        mutating commands get auditing (AuditableNpmCommand), query commands
        go through Safety's index (SearchCommand), anything else falls back
        to a plain NpmCommand.
        """
        parser = NpmParser()
        if intention := parser.parse(args):
            kwargs["intention"] = intention
            if intention.modifies_packages():
                return AuditableNpmCommand(args, **kwargs)
            if intention.queries_packages():
                return SearchCommand(args, **kwargs)
        return NpmCommand(args, **kwargs)
class NpmIndexEnvMixin:
    """
    Mixin that points npm at Safety's index via the process environment.

    Implementers are expected to define ``self._index_url`` (Optional[str]);
    when it is absent or None, the default Safety registry is used.
    """

    def env(self, ctx: typer.Context) -> dict:
        environment = super().env(ctx)  # pyright: ignore[reportAttributeAccessIssue]
        override = getattr(self, "_index_url", None)
        environment["NPM_CONFIG_REGISTRY"] = build_index_url(ctx, override, "npm")
        return environment
class SearchCommand(NpmIndexEnvMixin, NpmCommand):
    """Read-only npm query command routed through Safety's index."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # No per-invocation registry override; the mixin then falls back to
        # the default Safety index.
        self._index_url: Optional[str] = None
class AuditableNpmCommand(NpmIndexEnvMixin, NpmCommand, InstallationAuditMixin):
    """
    npm command that mutates packages: strips a Safety-hosted
    ``--registry``/``-r`` option from the argv (re-injecting it through
    ``NPM_CONFIG_REGISTRY``) and audits the installation afterwards.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._index_url = None

    def before(self, ctx: typer.Context):
        """
        Capture a Safety-hosted registry value, if any, and drop the option
        from the forwarded arguments.
        """
        super().before(ctx)
        args: List[Optional[str]] = self._args.copy()  # type: ignore

        if self._intention:
            if registry_opt := self._intention.options.get(
                "registry"
            ) or self._intention.options.get("r"):
                registry_value = registry_opt["value"]

                if registry_value and registry_value.startswith(
                    "https://pkgs.safetycli.com"
                ):
                    self._index_url = registry_value
                    arg_index = registry_opt["arg_index"]
                    value_index = registry_opt["value_index"]

                    # Compare against None explicitly: index 0 is falsy but is
                    # still a valid argv position to strip.
                    if (
                        arg_index is not None
                        and value_index is not None
                        and arg_index < len(args)
                        and value_index < len(args)
                    ):
                        args[arg_index] = None
                        args[value_index] = None

        self._args = [arg for arg in args if arg is not None]

    def after(self, ctx: typer.Context, result: ToolResult):
        super().after(ctx, result)
        self.handle_installation_audit(ctx, result)

View File

@@ -0,0 +1,171 @@
import logging
import shutil
import subprocess
from pathlib import Path
from typing import Optional
import typer
from rich.console import Console
from safety.tool.constants import (
NPMJS_PUBLIC_REPOSITORY_URL,
NPMJS_ORGANIZATION_REPOSITORY_URL,
NPMJS_PROJECT_REPOSITORY_URL,
)
from safety.tool.resolver import get_unwrapped_command
from safety.utils.pyapp_utils import get_path, get_env
from safety.console import main_console
from safety.tool.auth import build_index_url
logger = logging.getLogger(__name__)
class Npm:
    @classmethod
    def is_installed(cls) -> bool:
        """
        Checks if the NPM program is installed

        Returns:
            True if NPM is installed on system, or false otherwise
        """
        return shutil.which("npm", path=get_path()) is not None

    @classmethod
    def configure_project(
        cls,
        project_path: Path,
        org_slug: Optional[str],
        project_id: Optional[str],
        console: Console = main_console,
    ) -> Optional[Path]:
        """
        Configures Safety index url for specified npmrc file.

        Args:
            project_path (Path): Path to the project (or a file inside it).
            org_slug (Optional[str]): Organization slug.
            project_id (Optional[str]): Project identifier.
            console (Console): Console instance.

        Returns:
            The project root that was configured, or None on failure.
        """
        if not cls.is_installed():
            logger.error("NPM is not installed.")
            return None

        # Most specific URL wins: project > organization > public.
        repository_url = (
            NPMJS_PROJECT_REPOSITORY_URL.format(org_slug, project_id)
            if project_id and org_slug
            else (
                NPMJS_ORGANIZATION_REPOSITORY_URL.format(org_slug)
                if org_slug
                else NPMJS_PUBLIC_REPOSITORY_URL
            )
        )

        project_root = project_path.resolve()
        if not project_root.is_dir():
            project_root = project_path.parent

        result = subprocess.run(
            [
                get_unwrapped_command(name="npm"),
                "config",
                "set",
                "registry",
                repository_url,
                "--location",
                "project",
            ],
            capture_output=True,
            cwd=project_root,
            env=get_env(),
        )

        if result.returncode != 0:
            logger.error(
                f"Failed to configure NPM project settings: {result.stderr.decode('utf-8')}"
            )
            return None

        return project_root

    @classmethod
    def configure_system(
        cls, org_slug: Optional[str], console: Console = main_console
    ) -> Optional[Path]:
        """
        Configures NPM system to use to Safety index url.

        Returns:
            Path to npm's global config file, an empty Path when npm does not
            report one, or None on failure.
        """
        if not cls.is_installed():
            logger.error("NPM is not installed.")
            return None

        try:
            repository_url = (
                NPMJS_ORGANIZATION_REPOSITORY_URL.format(org_slug)
                if org_slug
                else NPMJS_PUBLIC_REPOSITORY_URL
            )
            result = subprocess.run(
                [
                    get_unwrapped_command(name="npm"),
                    "config",
                    "set",
                    "-g",
                    "registry",
                    repository_url,
                ],
                capture_output=True,
                env=get_env(),
            )
            if result.returncode != 0:
                logger.error(
                    f"Failed to configure NPM global settings: {result.stderr.decode('utf-8')}"
                )
                return None

            # Ask npm where the global config lives so callers can report or
            # clean it up later.
            query_config_result = subprocess.run(
                [
                    get_unwrapped_command(name="npm"),
                    "config",
                    "get",
                    "globalconfig",
                ],
                capture_output=True,
                env=get_env(),
            )
            config_file_path = query_config_result.stdout.decode("utf-8").strip()
            if config_file_path:
                return Path(config_file_path)

            logger.error("Failed to match the config file path written by NPM.")
            return Path()
        except Exception:
            logger.exception("Failed to configure NPM global settings.")
            return None

    @classmethod
    def reset_system(cls, console: Console = main_console):
        """
        Removes the Safety registry override from npm's global config.
        """
        # TODO: Move this logic and implement it in a more robust way
        try:
            # `config delete` removes the override so npm falls back to its
            # default registry; the previous `config set -g registry` with no
            # value could not restore the default.
            subprocess.run(
                [
                    get_unwrapped_command(name="npm"),
                    "config",
                    "delete",
                    "-g",
                    "registry",
                ],
                capture_output=True,
                env=get_env(),
            )
        except Exception:
            console.print("Failed to reset NPM global settings.")

    @classmethod
    def build_index_url(cls, ctx: typer.Context, index_url: Optional[str]) -> str:
        """Build the authenticated npm registry URL for this context."""
        return build_index_url(ctx, index_url, "npm")

View File

@@ -0,0 +1,195 @@
from typing import Dict
from ..base import ToolCommandLineParser
from ..intents import ToolIntentionType
from typing import Union, Set, Optional, Mapping
from ..intents import Dependency
# Subcommand aliases the npm CLI accepts for each package operation.
# The misspellings below ("isnt", "isntall", "udpate", ...) are aliases
# shipped by npm itself, not typos in this file.
ADD_PACKAGE_ALIASES = [
    "install",
    "add",
    "i",
    "in",
    "ins",
    "inst",
    "insta",
    "instal",
    "isnt",
    "isnta",
    "isntal",
    "isntall",
]
# Aliases that remove packages.
REMOVE_PACKAGE_ALIASES = [
    "uninstall",
    "unlink",
    "remove",
    "rm",
    "r",
    "un",
]
# Aliases that update packages.
UPDATE_PACKAGE_ALIASES = [
    "update",
    "up",
    "upgrade",
    "udpate",
]
# Aliases for clean installs from the lockfile.
SYNC_PACKAGES_ALIASES = [
    "ci",
    "clean-install",
    "ic",
    "install-clean",
    "isntall-clean",
]
# Aliases that list installed packages.
LIST_PACKAGES_ALIASES = [
    "list",
    "ls",
    "ll",
    "la",
]
# Aliases that query the registry without modifying anything.
SEARCH_PACKAGES_ALIASES = [
    # Via view
    "view",
    "info",
    "show",
    "v",
    # Via search
    "search",
    "find",
    "s",
    "se",
]
# Aliases that scaffold a new project/package.
INIT_PROJECT_ALIASES = [
    "init",
    "create",
]
class NpmParser(ToolCommandLineParser):
    """Parses npm command lines into tool intentions and dependencies."""

    def get_tool_name(self) -> str:
        return "npm"

    def get_command_hierarchy(self) -> Mapping[str, Union[ToolIntentionType, Mapping]]:
        """
        Context for command hierarchy parsing: a flat alias -> intention map.
        """
        alias_groups = [
            (ADD_PACKAGE_ALIASES, ToolIntentionType.ADD_PACKAGE),
            (REMOVE_PACKAGE_ALIASES, ToolIntentionType.REMOVE_PACKAGE),
            (UPDATE_PACKAGE_ALIASES, ToolIntentionType.UPDATE_PACKAGE),
            (SYNC_PACKAGES_ALIASES, ToolIntentionType.SYNC_PACKAGES),
            (SEARCH_PACKAGES_ALIASES, ToolIntentionType.SEARCH_PACKAGES),
            (LIST_PACKAGES_ALIASES, ToolIntentionType.LIST_PACKAGES),
            (INIT_PROJECT_ALIASES, ToolIntentionType.INIT_PROJECT),
        ]

        hierarchy: Dict[str, Union[ToolIntentionType, Mapping]] = {}
        for aliases, intention in alias_groups:
            for alias in aliases:
                hierarchy[alias.lower().strip()] = intention
        return hierarchy

    def get_known_flags(self) -> Dict[str, Set[str]]:
        """
        Define flags that DON'T take values to avoid consuming packages
        """
        save_and_output_flags = {
            "S",
            "save",
            "no-save",
            "save-prod",
            "save-dev",
            "save-optional",
            "save-peer",
            "save-bundle",
            "g",
            "global",
            "workspaces",
            "include-workspace-root",
            "install-links",
            "json",
            "no-color",
            "parseable",
            "p",
            "no-description",
            "prefer-offline",
            "offline",
        }
        install_behavior_flags = {
            "E",
            "save-exact",
            "legacy-bundling",
            "global-style",
            "strict-peer-deps",
            "prefer-dedupe",
            "no-package-lock",
            "package-lock-only",
            "foreground-scripts",
            "ignore-scripts",
            "no-audit",
            "no-bin-links",
            "no-fund",
            "dry-run",
        }
        # We don't need to differentiate between flags for different commands.
        return {"global": save_and_output_flags | install_behavior_flags}

    def _parse_package_spec(
        self, spec_str: str, arg_index: int
    ) -> Optional[Dependency]:
        """
        Parse npm registry specs like "react", "@types/node@^20",
        and aliases like "alias@npm:@sentry/node@7".
        Skips non-registry (git/url/path) specs by returning None.
        """
        import re

        registry_pattern = re.compile(
            r"""^(?P<name>@[^/\s]+/[^@\s]+|[A-Za-z0-9._-]+)(?:@(?P<constraint>.+))?$"""
        )
        alias_pattern = re.compile(
            r"""^(?P<alias>@?[^@\s/]+(?:/[^@\s/]+)?)@npm:(?P<target>.+)$"""
        )

        text = spec_str.strip()

        def build(name: str, constraint: Optional[str]) -> Dependency:
            return Dependency(
                name=name.lower(),
                version_constraint=(constraint or None),
                arg_index=arg_index,
                original_text=spec_str,
            )

        # Alias form: "alias@npm:target" — resolve the aliased target.
        if alias_match := alias_pattern.match(text):
            target_match = registry_pattern.match(alias_match.group("target").strip())
            if target_match:
                return build(alias_match.group("alias"), target_match.group("constraint"))
            # git/url/path alias targets are out of scope.
            return None

        # Plain registry form: "name[@constraint]".
        if registry_match := registry_pattern.match(text):
            return build(registry_match.group("name"), registry_match.group("constraint"))

        return None

View File

@@ -0,0 +1,3 @@
from .main import Pip
__all__ = ["Pip"]

View File

@@ -0,0 +1,116 @@
from typing import TYPE_CHECKING, List, Optional
import logging
import typer
from safety.models import ToolResult
from .parser import PipParser
from ..base import BaseCommand
from safety_schemas.models.events.types import ToolType
from ..environment_diff import EnvironmentDiffTracker, PipEnvironmentDiffTracker
from ..mixins import InstallationAuditMixin
from .main import Pip
if TYPE_CHECKING:
from ..environment_diff import EnvironmentDiffTracker
logger = logging.getLogger(__name__)
class PipCommand(BaseCommand):
    """
    Main class for hooks into pip commands.
    """

    def get_tool_type(self) -> ToolType:
        return ToolType.PIP

    def get_command_name(self) -> List[str]:
        """
        Return the command to invoke, honoring any alias the user typed so
        versioned binaries (pip3.13, pip3.12, ...) keep working.
        """
        if self._command_alias_used:
            return [self._command_alias_used]
        return ["pip"]

    def get_diff_tracker(self) -> "EnvironmentDiffTracker":
        return PipEnvironmentDiffTracker()

    @classmethod
    def from_args(cls, args: List[str], **kwargs):
        """
        Choose the most specific pip command wrapper for ``args``.
        """
        if intention := PipParser().parse(args):
            kwargs["intention"] = intention

            if intention.modifies_packages():
                return AuditablePipCommand(args, **kwargs)
            if intention.queries_packages():
                return SearchCommand(args, **kwargs)

        return PipCommand(args, **kwargs)
class PipIndexEnvMixin:
    """
    Mixin that injects Safety's index URL into pip's environment.

    Implementers are expected to define ``self._index_url`` (Optional[str]).
    """

    def env(self, ctx: typer.Context) -> dict:
        environment = super().env(ctx)  # pyright: ignore[reportAttributeAccessIssue]
        index_url = Pip.build_index_url(ctx, getattr(self, "_index_url", None))
        environment["PIP_INDEX_URL"] = index_url
        environment["PIP_PYPI_URL"] = index_url
        return environment
class SearchCommand(PipIndexEnvMixin, PipCommand):
    """Read-only pip query command routed through Safety's index."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # No per-invocation index override; the mixin then falls back to the
        # default Safety index.
        self._index_url: Optional[str] = None
class AuditablePipCommand(PipIndexEnvMixin, PipCommand, InstallationAuditMixin):
    """
    pip command that mutates packages: strips a Safety-hosted
    ``--index-url``/``-i`` option from the argv (re-injecting it through
    ``PIP_INDEX_URL``) and audits the installation afterwards.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._index_url = None

    def before(self, ctx: typer.Context):
        """
        Capture a Safety-hosted index value, if any, and drop the option from
        the forwarded arguments.
        """
        super().before(ctx)
        args: List[Optional[str]] = self._args.copy()  # type: ignore

        if self._intention:
            if index_opt := self._intention.options.get(
                "index-url"
            ) or self._intention.options.get("i"):
                index_value = index_opt["value"]

                if index_value and index_value.startswith("https://pkgs.safetycli.com"):
                    self._index_url = index_value
                    arg_index = index_opt["arg_index"]
                    value_index = index_opt["value_index"]

                    # Compare against None explicitly: index 0 is falsy but is
                    # still a valid argv position to strip.
                    if (
                        arg_index is not None
                        and value_index is not None
                        and arg_index < len(args)
                        and value_index < len(args)
                    ):
                        args[arg_index] = None
                        args[value_index] = None

        self._args = [arg for arg in args if arg is not None]

    def after(self, ctx: typer.Context, result: ToolResult):
        super().after(ctx, result)
        self.handle_installation_audit(ctx, result)

View File

@@ -0,0 +1,150 @@
import logging
import re
import shutil
import subprocess
from pathlib import Path
from typing import Optional
import typer
from rich.console import Console
from safety.tool.constants import (
PYPI_PUBLIC_REPOSITORY_URL,
PYPI_ORGANIZATION_REPOSITORY_URL,
PYPI_PROJECT_REPOSITORY_URL,
)
from safety.tool.resolver import get_unwrapped_command
from safety.utils.pyapp_utils import get_path, get_env
from safety.console import main_console
from safety.tool.auth import build_index_url
from ...encoding import detect_encoding
logger = logging.getLogger(__name__)
class Pip:
    @classmethod
    def is_installed(cls) -> bool:
        """
        Checks if the PIP program is installed

        Returns:
            True if PIP is installed on system, or false otherwise
        """
        return shutil.which("pip", path=get_path()) is not None

    @classmethod
    def configure_requirements(
        cls,
        file: Path,
        org_slug: Optional[str],
        project_id: Optional[str],
        console: Console = main_console,
    ) -> Optional[Path]:
        """
        Configures Safety index url for specified requirements file.

        Args:
            file (Path): Path to requirements.txt file.
            org_slug (Optional[str]): Organization slug.
            project_id (Optional[str]): Project identifier.
            console (Console): Console instance.

        Returns:
            The file that was configured, or None when already configured.
        """
        with open(file, "r+", encoding=detect_encoding(file)) as f:
            content = f.read()

            # Most specific URL wins: project > organization > public.
            repository_url = (
                PYPI_PROJECT_REPOSITORY_URL.format(org_slug, project_id)
                if project_id and org_slug
                else (
                    PYPI_ORGANIZATION_REPOSITORY_URL.format(org_slug)
                    if org_slug
                    else PYPI_PUBLIC_REPOSITORY_URL
                )
            )
            index_config = f"-i {repository_url}\n"

            if index_config not in content:
                # Prepend the index line; the new content is strictly longer
                # than the old, so no truncation is needed after seek(0).
                f.seek(0)
                f.write(index_config + content)
                logger.info(f"Configured {file} file")
                return file
            else:
                logger.info(f"{file} is already configured. Skipping.")
                return None

    @classmethod
    def configure_system(
        cls, org_slug: Optional[str], console: Console = main_console
    ) -> Optional[Path]:
        """
        Configures PIP system to use to Safety index url.

        Returns:
            Path to the pip config file written, an empty Path when it cannot
            be determined, or None on failure.
        """
        try:
            repository_url = (
                PYPI_ORGANIZATION_REPOSITORY_URL.format(org_slug)
                if org_slug
                else PYPI_PUBLIC_REPOSITORY_URL
            )
            result = subprocess.run(
                [
                    get_unwrapped_command(name="pip"),
                    "config",
                    "--user",
                    "set",
                    "global.index-url",
                    repository_url,
                ],
                capture_output=True,
                env=get_env(),
            )
            if result.returncode != 0:
                logger.error(
                    f"Failed to configure PIP global settings: {result.stderr.decode('utf-8')}"
                )
                return None

            # pip reports the config file it wrote as "Writing to <path>".
            output = result.stdout.decode("utf-8")
            match = re.search(r"Writing to (.+)", output)
            if match:
                config_file_path = match.group(1).strip()
                return Path(config_file_path)

            logger.error("Failed to match the config file path written by pip.")
            return Path()
        except Exception:
            logger.exception("Failed to configure PIP global settings.")
            return None

    @classmethod
    def reset_system(cls, console: Console = main_console):
        """
        Removes the Safety index override from pip's user config.
        """
        # TODO: Move this logic and implement it in a more robust way
        try:
            subprocess.run(
                [
                    get_unwrapped_command(name="pip"),
                    "config",
                    "--user",
                    "unset",
                    "global.index-url",
                ],
                capture_output=True,
                env=get_env(),
            )
        except Exception:
            console.print("Failed to reset PIP global settings.")

    @classmethod
    def default_index_url(cls) -> str:
        """Public PyPI simple index used when no Safety index applies."""
        return "https://pypi.org/simple/"

    @classmethod
    def build_index_url(cls, ctx: typer.Context, index_url: Optional[str]) -> str:
        """Build the authenticated PyPI index URL for this context."""
        return build_index_url(ctx, index_url, "pypi")

View File

@@ -0,0 +1,105 @@
from typing import Dict
from ..base import ToolCommandLineParser
from ..intents import ToolIntentionType
from typing import Union, Set
class PipParser(ToolCommandLineParser):
    """Parses pip command lines into tool intentions."""

    def get_tool_name(self) -> str:
        return "pip"

    def get_command_hierarchy(self) -> Dict[str, Union[ToolIntentionType, Dict]]:
        """
        Context for command hierarchy parsing
        """
        return {
            "install": ToolIntentionType.ADD_PACKAGE,
            "uninstall": ToolIntentionType.REMOVE_PACKAGE,
            "download": ToolIntentionType.DOWNLOAD_PACKAGE,
            "wheel": ToolIntentionType.DOWNLOAD_PACKAGE,
            # NOTE(review): pip has no "query" subcommand upstream — confirm
            # whether this should be "search" or is intentional.
            "query": ToolIntentionType.SEARCH_PACKAGES,
            # Nested mapping handles "pip index versions <pkg>".
            "index": {
                "versions": ToolIntentionType.SEARCH_PACKAGES,
            },
        }

    def get_known_flags(self) -> Dict[str, Set[str]]:
        """
        Define flags that DON'T take values to avoid consuming packages
        """
        return {
            # Global flags (available for all commands)
            "global": {
                "help",
                "h",
                "debug",
                "isolated",
                "require-virtualenv",
                "verbose",
                "v",
                "version",
                "V",
                "quiet",
                "q",
                "no-input",
                "no-cache-dir",
                "disable-pip-version-check",
                "no-color",
                # Index specific
                "no-index",
            },
            # install-specific flags
            "install": {
                "no-deps",
                "pre",
                "dry-run",
                "user",
                "upgrade",
                "U",
                "force-reinstall",
                "ignore-installed",
                "I",
                "ignore-requires-python",
                "no-build-isolation",
                "use-pep517",
                "no-use-pep517",
                "check-build-dependencies",
                "break-system-packages",
                "compile",
                "no-compile",
                "no-warn-script-location",
                "no-warn-conflicts",
                "prefer-binary",
                "require-hashes",
                "no-clean",
            },
            # uninstall-specific flags
            "uninstall": {
                "yes",
                "y",
                "break-system-packages",
            },
            # download-specific flags
            "download": {
                "no-deps",
                "no-binary",
                "only-binary",
                "prefer-binary",
                "pre",
                "require-hashes",
                "no-build-isolation",
                "use-pep517",
                "no-use-pep517",
                "check-build-dependencies",
                "ignore-requires-python",
                "no-clean",
            },
            # Dotted key matches the nested "index versions" hierarchy entry.
            "index.versions": {
                "ignore-requires-python",
                "pre",
                "json",
                "no-index",
            },
        }

View File

@@ -0,0 +1,5 @@
from .main import Poetry
__all__ = [
"Poetry",
]

View File

@@ -0,0 +1,178 @@
from pathlib import Path
import sys
from typing import TYPE_CHECKING, List, Optional, Tuple
import logging
import typer
from safety.tool.utils import PoetryPyprojectConfigurator
from .constants import MSG_SAFETY_SOURCE_ADDED, MSG_SAFETY_SOURCE_NOT_ADDED
from .parser import PoetryParser
from ..auth import index_credentials
from ..base import BaseCommand, ToolIntentionType
from ..mixins import InstallationAuditMixin
from ..environment_diff import EnvironmentDiffTracker, PipEnvironmentDiffTracker
from safety_schemas.models.events.types import ToolType
from safety.console import main_console as console
from safety.models import ToolResult
if TYPE_CHECKING:
from ..environment_diff import EnvironmentDiffTracker
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
logger = logging.getLogger(__name__)
class PoetryCommand(BaseCommand):
    """
    Main class for hooks into poetry commands.
    """

    def get_tool_type(self) -> ToolType:
        return ToolType.POETRY

    def get_command_name(self) -> List[str]:
        return ["poetry"]

    def get_diff_tracker(self) -> "EnvironmentDiffTracker":
        # pip diff tracker will be enough for poetry
        return PipEnvironmentDiffTracker()

    def get_package_list_command(self) -> List[str]:
        """
        Get the package list of a poetry virtual environment.

        This implementation uses pip to list packages.

        Returns:
            List[str]: Command to list packages in JSON format
        """
        return ["poetry", "run", "pip", "list", "-v", "--format=json"]

    def _has_safety_source_in_pyproject(self) -> bool:
        """
        Check if 'safety' source exists in pyproject.toml
        """
        # Looks in the current working directory, not the project path.
        if not Path("pyproject.toml").exists():
            return False
        try:
            # Parse the TOML file
            with open("pyproject.toml", "rb") as f:
                pyproject = tomllib.load(f)
            poetry_config = pyproject.get("tool", {}).get("poetry", {})
            sources = poetry_config.get("source", [])
            # [[tool.poetry.source]] normally yields a list of tables, but a
            # plain [tool.poetry.source] table is also tolerated here.
            if isinstance(sources, dict):
                return "safety" in sources
            else:
                return any(source.get("name") == "safety" for source in sources)
        except (FileNotFoundError, KeyError, tomllib.TOMLDecodeError):
            return False

    def before(self, ctx: typer.Context):
        """
        Ensure the Safety source is present in pyproject.toml before any
        command that resolves/installs packages runs.
        """
        super().before(ctx)
        # Only mutate pyproject.toml for operations that will touch packages.
        if self._intention and self._intention.intention_type in [
            ToolIntentionType.SYNC_PACKAGES,
            ToolIntentionType.ADD_PACKAGE,
        ]:
            if not self._has_safety_source_in_pyproject():
                org_slug = None
                try:
                    # Best effort: the org slug comes from the auth initialize
                    # endpoint; failures fall back to org_slug=None.
                    data = ctx.obj.auth.client.initialize()
                    org_slug = data.get("organization-data", {}).get("slug")
                except Exception:
                    logger.exception(
                        "Unable to pull the org slug from the initialize endpoint."
                    )
                try:
                    configurator = PoetryPyprojectConfigurator()
                    prj_slug = ctx.obj.project.id if ctx.obj.project else None
                    if configurator.configure(
                        Path("pyproject.toml"), org_slug, prj_slug
                    ):
                        console.print(
                            MSG_SAFETY_SOURCE_ADDED,
                        )
                except Exception:
                    logger.exception("Unable to configure the pyproject.toml file.")
                    console.print(
                        MSG_SAFETY_SOURCE_NOT_ADDED,
                    )

    def env(self, ctx: typer.Context) -> dict:
        """
        Supply basic-auth credentials for the "safety" source through
        Poetry's POETRY_HTTP_BASIC_* environment variables.
        """
        env = super().env(ctx)
        env.update(
            {
                "POETRY_HTTP_BASIC_SAFETY_USERNAME": "user",
                "POETRY_HTTP_BASIC_SAFETY_PASSWORD": index_credentials(ctx),
            }
        )
        return env

    @classmethod
    def from_args(cls, args: List[str], **kwargs):
        """
        Build the most specific poetry command wrapper for ``args``.
        """
        parser = PoetryParser()
        if intention := parser.parse(args):
            kwargs["intention"] = intention
            if intention.modifies_packages():
                return AuditablePoetryCommand(args, **kwargs)
        return PoetryCommand(args, **kwargs)
class AuditablePoetryCommand(PoetryCommand, InstallationAuditMixin):
    """Poetry command that mutates packages and is audited after running."""

    def patch_source_option(
        self, args: List[str], new_source: str = "safety"
    ) -> Tuple[Optional[str], List[str]]:
        """
        Locate the first ``--source`` option and rewrite its value.

        Args:
            args: Command line arguments.
            new_source: Source name to substitute in.

        Returns:
            tuple: (original source value or None, patched copy of args)
        """
        patched = args.copy()

        for position, token in enumerate(args):
            if token.startswith("--source="):
                # Inline form: --source=value
                original_value = token.split("=", 1)[1]
                patched[position] = f"--source={new_source}"
                return original_value, patched
            if token == "--source" and position + 1 < len(args):
                # Separate form: --source value
                original_value = args[position + 1]
                patched[position + 1] = new_source
                return original_value, patched

        return None, patched

    def before(self, ctx: typer.Context):
        super().before(ctx)
        _, patched_args = self.patch_source_option(self._args)
        self._args = patched_args

    def after(self, ctx: typer.Context, result: ToolResult):
        super().after(ctx, result)
        self.handle_installation_audit(ctx, result)

View File

@@ -0,0 +1,4 @@
# User-facing Firewall configuration messages; rendered verbatim by the
# console, so the text (including the leading newline) is part of behavior.
MSG_SAFETY_SOURCE_NOT_ADDED = "\nError: Safety Firewall could not be added as a source in your pyproject.toml file. You will not be protected from malicious or insecure packages. Please run `safety init` to fix this."
MSG_SAFETY_SOURCE_ADDED = (
    "\nSafety Firewall has been added as a source to protect this codebase"
)

View File

@@ -0,0 +1,103 @@
import logging
import shutil
import subprocess
from pathlib import Path
import sys
from typing import Optional
from rich.console import Console
from safety.console import main_console
from safety.tool.constants import (
PYPI_PUBLIC_REPOSITORY_URL,
PYPI_ORGANIZATION_REPOSITORY_URL,
PYPI_PROJECT_REPOSITORY_URL,
)
from safety.tool.resolver import get_unwrapped_command
from safety.utils.pyapp_utils import get_path, get_env
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
logger = logging.getLogger(__name__)
class Poetry:
    @classmethod
    def is_installed(cls) -> bool:
        """
        Checks if the Poetry program is installed

        Returns:
            True if Poetry is installed on system, or false otherwise
        """
        return shutil.which("poetry", path=get_path()) is not None

    @classmethod
    def is_poetry_project_file(cls, file: Path) -> bool:
        """
        Heuristically decide whether *file* is a Poetry pyproject.toml.
        """
        try:
            cfg = tomllib.loads(file.read_text())

            # First check: tool.poetry section (most definitive)
            if "tool" in cfg and "poetry" in cfg.get("tool", {}):
                return True

            # Extra check on build-system section
            build_backend = cfg.get("build-system", {}).get("build-backend", "")
            if build_backend and "poetry.core" in build_backend:
                return True

            return False
        except (IOError, ValueError):
            # Unreadable file or invalid TOML (TOMLDecodeError is a ValueError).
            return False

    @classmethod
    def configure_pyproject(
        cls,
        file: Path,
        org_slug: Optional[str],
        project_id: Optional[str] = None,
        console: Console = main_console,
    ) -> Optional[Path]:
        """
        Configures index url for specified pyproject file.

        Args:
            file (Path): Path to pyproject.toml file.
            org_slug (Optional[str]): Organization slug.
            project_id (Optional[str]): Project ID.
            console (Console): Console instance.

        Returns:
            The file on success, or None on failure.
        """
        if not cls.is_installed():
            logger.error("Poetry is not installed.")
            return None

        # Most specific URL wins: project > organization > public.
        repository_url = (
            PYPI_PROJECT_REPOSITORY_URL.format(org_slug, project_id)
            if project_id and org_slug
            else (
                PYPI_ORGANIZATION_REPOSITORY_URL.format(org_slug)
                if org_slug
                else PYPI_PUBLIC_REPOSITORY_URL
            )
        )

        # NOTE(review): runs in the current working directory, not file.parent —
        # confirm callers always invoke this from the project root.
        result = subprocess.run(
            [
                get_unwrapped_command(name="poetry"),
                "source",
                "add",
                "safety",
                repository_url,
            ],
            capture_output=True,
            env=get_env(),
        )

        if result.returncode != 0:
            # Include poetry's stderr for parity with the pip/npm configurators.
            logger.error(
                f"Failed to configure {file} file: {result.stderr.decode('utf-8')}"
            )
            return None

        return file

View File

@@ -0,0 +1,134 @@
from typing import Dict, Optional, Union, Set
from ..base import ToolCommandLineParser
from ..intents import Dependency, ToolIntentionType
class PoetryParser(ToolCommandLineParser):
    """Parses poetry command lines into intentions and dependencies."""

    def get_tool_name(self) -> str:
        return "poetry"

    def get_command_hierarchy(self) -> Dict[str, Union[ToolIntentionType, Dict]]:
        """
        Allow base parser to recognize poetry commands and intentions.
        """
        hierarchy: Dict[str, Union[ToolIntentionType, Dict]] = {
            "add": ToolIntentionType.ADD_PACKAGE,
            "remove": ToolIntentionType.REMOVE_PACKAGE,
            "update": ToolIntentionType.UPDATE_PACKAGE,
            "install": ToolIntentionType.SYNC_PACKAGES,
            "build": ToolIntentionType.BUILD_PROJECT,
            "show": ToolIntentionType.LIST_PACKAGES,
            "init": ToolIntentionType.INIT_PROJECT,
        }
        return hierarchy

    def get_known_flags(self) -> Dict[str, Set[str]]:
        """
        Flags that DO NOT take a value, derived from `poetry --help` and subcommand helps.
        """
        known_flags: Dict[str, Set[str]] = {
            "global": {
                "help",
                "h",
                "quiet",
                "q",
                "version",
                "V",
                "ansi",
                "no-ansi",
                "no-interaction",
                "n",
                "no-plugins",
                "no-cache",
                "verbose",
                "v",
                "vv",
                "vvv",
            },
            "add": {
                "dev",
                "D",
                "editable",
                "e",
                "allow-prereleases",
                "dry-run",
                "lock",
            },
            "remove": {
                "dev",
                "D",
                "dry-run",
                "lock",
            },
            "update": {
                "sync",
                "dry-run",
                "lock",
            },
            "install": {
                "sync",
                "no-root",
                "no-directory",
                "dry-run",
                "all-extras",
                "all-groups",
                "only-root",
                "compile",
            },
            "build": {
                "clean",
            },
        }
        return known_flags

    def _parse_package_spec(
        self, spec_str: str, arg_index: int
    ) -> Optional[Dependency]:
        """
        Parse a package specification string into a Dependency object.

        Handles Poetry-specific operators (@, ^, ~) as well as standard
        PEP 508 requirements.

        Args:
            spec_str: Package specification string (e.g. "requests>=2.25.0[security]")
            arg_index: Position of the spec in the original argv.

        Returns:
            Dependency: Parsed dependency information, or None when the spec
            cannot be parsed.
        """
        try:
            # TODO: This is a very basic implementation and not well tested;
            # our main target for now is to get the package name.
            from packaging.requirements import Requirement

            has_poetry_operator = True
            if "@" in spec_str and not spec_str.startswith("git+"):
                # Poetry's "package@version" form
                requirement_text = spec_str.split("@")[0]
            elif "^" in spec_str:
                # Caret requirement: package^version
                requirement_text = spec_str.split("^")[0]
            elif "~" in spec_str and "~=" not in spec_str:
                # Tilde requirement: package~version (but not PEP 440 "~=")
                requirement_text = spec_str.split("~")[0]
            else:
                # Common PEP 440 cases; keep the version specifier.
                requirement_text = spec_str
                has_poetry_operator = False

            requirement = Requirement(requirement_text)
            constraint = None if has_poetry_operator else str(requirement.specifier)

            return Dependency(
                name=requirement.name,
                version_constraint=constraint,
                extras=requirement.extras,
                arg_index=arg_index,
                original_text=spec_str,
            )
        except Exception:
            # If spec parsing fails, just ignore for now
            return None

View File

@@ -0,0 +1,66 @@
import sys
import subprocess
import shutil
import logging
from safety.utils.pyapp_utils import get_path, get_env
logger = logging.getLogger(__name__)
def get_unwrapped_command(name: str) -> str:
    """
    Find the true executable for a command, skipping wrappers/aliases/.bat files.

    Args:
        name: The command to resolve (e.g. 'pip', 'python')

    Returns:
        Path to the actual executable, or the bare name when no better
        candidate is found.
    """
    logger.debug("get_unwrapped_command called with name: %s", name)

    if sys.platform == "win32":
        # Prefer the .exe lookup first; fall back to the bare name.
        for lookup_term in [f"{name}.exe", name]:
            logger.debug("Windows platform detected, looking for: %s", lookup_term)
            where_result = subprocess.run(
                ["where.exe", lookup_term],
                capture_output=True,
                text=True,
                env=get_env(),
            )
            logger.debug("where.exe returncode: %s", where_result.returncode)
            logger.debug("where.exe stdout: %s", where_result.stdout)
            logger.debug("where.exe stderr: %s", where_result.stderr)

            if where_result.returncode != 0:
                continue

            for path in where_result.stdout.splitlines():
                path = path.strip()
                if not path:
                    continue
                logger.debug("Checking path: %s", path)
                path_lower = path.lower()
                # Only real executables are candidates.
                if not path_lower.endswith((".exe", ".bat", ".cmd")):
                    logger.debug("Skipping non-executable path: %s", path)
                    continue
                # Skip Safety's own interception wrapper.
                if "\\safety\\" in path_lower and path_lower.endswith(f"{name}.bat"):
                    logger.debug("Skipping Safety wrapper: %s", path)
                    continue
                return path

        logger.debug("No unwrapped command found on Windows, returning bare name: %s", name)
        return name

    fallback = shutil.which(name, path=get_path()) or name
    logger.debug("Using fallback (shutil.which or name): %s", fallback)
    return fallback

View File

@@ -0,0 +1,276 @@
import asyncio
import functools
import os
import platform
import re
import shutil
from typing import Dict, List, Optional, Set, Union
from safety_schemas.models.events.payloads import ToolStatus, AliasConfig, IndexConfig
from safety_schemas.models.events.types import ToolType
from safety.utils.pyapp_utils import get_path
import logging
logger = logging.getLogger(__name__)
class ToolInspector:
"""
Inspects the system for installed tools managers and their versions.
"""
COMMON_LOCATIONS = {
# Common paths across many tools
"COMMON": ["/usr/local/bin", "/usr/bin", "~/.local/bin"],
# Tool-specific paths
ToolType.PIP: [
# Virtual environments
"venv/bin",
"env/bin",
".venv/bin",
".env/bin",
# Python installations
"/opt/python*/bin",
# Windows specific
"C:/Python*/Scripts",
"%APPDATA%/Python/Python*/Scripts",
# macOS specific
"/Library/Frameworks/Python.framework/Versions/*/bin",
],
ToolType.POETRY: [
"~/.poetry/bin",
# Windows
"%APPDATA%/Python/poetry/bin",
"%USERPROFILE%/.poetry/bin",
],
ToolType.CONDA: [
"~/miniconda3/bin",
"~/anaconda3/bin",
"/opt/conda/bin",
"/opt/miniconda3/bin",
"/opt/anaconda3/bin",
# Windows
"C:/ProgramData/Miniconda3",
"C:/ProgramData/Anaconda3",
"%USERPROFILE%/Miniconda3",
"%USERPROFILE%/Anaconda3",
],
ToolType.UV: [
"~/.cargo/bin",
# Windows
"%USERPROFILE%/.cargo/bin",
],
ToolType.NPM: [
"~/.nvm/versions/node/*/bin",
# Windows
"%APPDATA%/npm",
"C:/Program Files/nodejs",
],
}
# Command arguments to check version
VERSION_ARGS = {
ToolType.PIP: "--version",
ToolType.UV: "--version",
ToolType.NPM: "--version",
ToolType.POETRY: "--version",
ToolType.CONDA: "--version",
}
# Version parsing regex
VERSION_REGEX = {
ToolType.PIP: r"pip (\d+\.\d+(?:\.\d+)?)",
ToolType.UV: r"uv (\d+\.\d+(?:\.\d+)?)",
ToolType.NPM: r"(\d+\.\d+\.\d+)",
ToolType.POETRY: r"Poetry version (\d+\.\d+\.\d+)",
ToolType.CONDA: r"conda (\d+\.\d+(?:\.\d+)?)",
}
    def __init__(self, timeout: float = 1.0):
        """
        Initialize the detector.

        Args:
            timeout: Command execution timeout in seconds
        """
        self.timeout = timeout
        # Per-tool cache of discovered executable paths; populated lazily by
        # _find_executable_paths.
        self._found_paths: Dict[ToolType, Set[str]] = {t: set() for t in ToolType}
# TODO: limit concurrency
async def inspect_all_tools(self) -> List[ToolStatus]:
"""
Inspect all tools installed in the system.
Returns:
List of ToolStatus objects for each found tool
"""
tasks = []
for tool_type in ToolType:
tasks.append(self._find_tool_instances(tool_type))
results: List[Union[List[ToolStatus], BaseException]] = await asyncio.gather(
*tasks, return_exceptions=True
)
tools_inspected: List[ToolStatus] = []
for tool_status in results:
if isinstance(tool_status, list):
tools_inspected.extend(tool_status)
return tools_inspected
async def _find_tool_instances(self, tool_type: ToolType) -> List[ToolStatus]:
"""
Find all instances of a specific tool type.
"""
# Find all executable paths
paths = await self._find_executable_paths(tool_type)
tasks = [self._check_tool(tool_type, path) for path in paths]
results: List[
Optional[Union[ToolStatus, BaseException]]
] = await asyncio.gather(*tasks, return_exceptions=True)
tools_inspected: List[ToolStatus] = []
for tool_status in results:
if isinstance(tool_status, ToolStatus):
tools_inspected.append(tool_status)
return tools_inspected
    def _search_executable_paths(self, tool_type: ToolType) -> Set[str]:
        """
        Scan PATH and well-known install locations for the tool's executable.

        Blocking (performs filesystem walks); intended to run in an executor.
        Returns a set of absolute paths to candidate executables.
        """
        # Get the executable name
        exe_name = tool_type.value
        if platform.system() == "Windows":
            exe_name = f"{exe_name}.exe"
        paths = set()
        # First, whatever resolves on the effective PATH.
        path_result = shutil.which(exe_name, path=get_path())
        if path_result:
            paths.add(os.path.abspath(path_result))
        # Then, the common per-tool install locations.
        for location_pattern in (
            self.COMMON_LOCATIONS["COMMON"] + self.COMMON_LOCATIONS[tool_type]
        ):
            if location_pattern.startswith("~"):
                location_pattern = os.path.expanduser(location_pattern)
            if "%" in location_pattern:
                # Windows-style %VAR% environment references.
                location_pattern = os.path.expandvars(location_pattern)
            # Handle wildcards
            if "*" in location_pattern:
                # This is a simplified wildcard expansion - a more robust implementation
                # would use glob or similar, but this is faster for common cases
                base_dir = location_pattern.split("*")[0]
                if os.path.exists(base_dir):
                    # Recursive walk from the fixed prefix; broader than a
                    # strict glob match, which is acceptable here.
                    for root, dirs, files in os.walk(base_dir, followlinks=False):
                        if exe_name in files:
                            exe_path = os.path.join(root, exe_name)
                            if os.access(exe_path, os.X_OK):
                                paths.add(os.path.abspath(exe_path))
            else:
                # Direct path check
                exe_path = os.path.join(location_pattern, exe_name)
                if os.path.exists(exe_path) and os.access(exe_path, os.X_OK):
                    paths.add(os.path.abspath(exe_path))
        return paths
async def _find_executable_paths(self, tool_type: ToolType) -> Set[str]:
"""
Find all executable paths for a tool type.
"""
if self._found_paths[tool_type]:
return self._found_paths[tool_type]
paths = await asyncio.get_event_loop().run_in_executor(
None, functools.partial(self._search_executable_paths, tool_type)
)
self._found_paths[tool_type] = paths
return paths
    async def _kill_process(self, proc):
        """
        Kill a subprocess and wait briefly for it to exit.

        Any failure (already-exited process, wait timeout, ...) is logged
        and swallowed so that cleanup never propagates errors.
        """
        if proc is None:
            return
        try:
            proc.kill()
            await asyncio.wait_for(proc.wait(), timeout=1.0)
        except Exception:
            logger.exception("Error killing process")
async def _check_tool(self, tool_type: ToolType, path: str) -> Optional[ToolStatus]:
"""
Check if a tool at a specific path is reachable and get its version.
"""
proc = None
try:
version_arg = self.VERSION_ARGS[tool_type]
proc = await asyncio.create_subprocess_exec(
path,
version_arg,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
try:
stdout, stderr = await asyncio.wait_for(
proc.communicate(), timeout=self.timeout
)
# Get data we need
output = stdout.decode() + stderr.decode()
returncode = proc.returncode
# Clear references to help garbage collection
proc = None
# Extract version
version_match = re.search(self.VERSION_REGEX[tool_type], output)
version = version_match.group(1) if version_match else "unknown"
AliasConfig(is_configured=True)
IndexConfig(is_configured=True)
return ToolStatus(
type=tool_type,
command_path=path,
version=version,
reachable=returncode == 0,
alias_config=None,
index_config=None,
)
except (asyncio.TimeoutError, TimeoutError):
if proc:
await self._kill_process(proc)
# Clear references to help garbage collection
proc = None
# Command timed out
return ToolStatus(
type=tool_type,
command_path=path,
version="unknown",
reachable=False,
)
except Exception:
logger.exception("Error checking tool")
# Any other error means the tool is not reachable
if proc:
await self._kill_process(proc)
# Clear reference to help garbage collection
proc = None
return ToolStatus(
type=tool_type, command_path=path, version="unknown", reachable=False
)

View File

@@ -0,0 +1,82 @@
"""
Typosquatting detection for various tools.
"""
import logging
import nltk
from typing import Tuple
from safety.console import main_console as console
from rich.prompt import Prompt
from .intents import CommandToolIntention, ToolIntentionType
logger = logging.getLogger(__name__)
class TyposquattingProtection:
    """
    Guards package operations against typosquatting by comparing a
    requested package name to a reference list of popular packages.
    """

    def __init__(self, popular_packages: Tuple[str, ...]):
        # Reference list of well-known package names to compare against.
        self.popular_packages = popular_packages

    def check_package(self, package_name: str) -> Tuple[bool, str]:
        """
        Check if a package name is likely to be a typosquatting attempt.

        Args:
            package_name: Name of the package to check

        Returns:
            Tuple of (is_valid, suggested_package_name)
        """
        # Longer names tolerate a larger edit distance before they are
        # considered suspicious.
        max_edit_distance = 2 if len(package_name) > 5 else 1
        if package_name in self.popular_packages:
            return (True, package_name)
        for pkg in self.popular_packages:
            # Cheap length pre-filter first; edit distance is the
            # expensive call.
            if (
                abs(len(pkg) - len(package_name)) <= max_edit_distance
                and nltk.edit_distance(pkg, package_name) <= max_edit_distance
            ):
                return (False, pkg)
        return (True, package_name)

    def coerce(self, intention: CommandToolIntention, dependency_name: str) -> str:
        """
        Coerce a package name to its correct name if it is a typosquatting attempt.

        Prompts the user interactively before substituting the suggestion.

        Args:
            intention: CommandToolIntention object
            dependency_name: Name of the package to coerce

        Returns:
            str: Coerced package name
        """
        (valid, candidate_package_name) = self.check_package(dependency_name)
        if not valid:
            # Pick the verb used in the confirmation prompt.
            action = "install"
            if intention.intention_type == ToolIntentionType.DOWNLOAD_PACKAGE:
                action = "download"
            elif intention.intention_type == ToolIntentionType.BUILD_PROJECT:
                action = "build"
            elif intention.intention_type == ToolIntentionType.SEARCH_PACKAGES:
                action = "search"
            prompt = f"You are about to {action} {dependency_name} package. Did you mean to {action} {candidate_package_name}?"
            answer = Prompt.ask(
                prompt=prompt,
                choices=["y", "n"],
                default="y",
                show_default=True,
                console=console,
            ).lower()
            if answer == "y":
                return candidate_package_name
        return dependency_name

View File

@@ -0,0 +1,161 @@
import abc
import os.path
import re
from pathlib import Path
from sys import platform
from safety.tool.pip import Pip
from safety.tool.poetry import Poetry
from safety.tool.npm import Npm
from typing import Any, TYPE_CHECKING, Optional
from safety.tool.uv.main import Uv
if TYPE_CHECKING:
pass
def is_os_supported():
    """
    Return True when running on a supported platform
    (Linux, macOS, or Windows).
    """
    return platform in {"linux", "linux2", "darwin", "win32"}
class BuildFileConfigurator(abc.ABC):
    """
    Interface for configurators that rewrite project/build files so a
    package manager resolves dependencies through Safety's index.
    """

    @abc.abstractmethod
    def is_supported(self, file: Path) -> bool:
        """
        Returns whether a specific file is supported by this class.

        Args:
            file (Path): The file to check.

        Returns:
            bool: Whether the file is supported by this class.
        """
        pass

    @abc.abstractmethod
    def configure(
        self, file: Path, org_slug: Optional[str], project_id: Optional[str]
    ) -> Optional[Path]:
        """
        Configures specific file.

        Args:
            file (Path): The file to configure.
            org_slug (Optional[str]): The organization slug.
            project_id (Optional[str]): The project identifier.

        Returns:
            Optional[Path]: The configured file, if any.
        """
        pass
class PipRequirementsConfigurator(BuildFileConfigurator):
    """
    Configures pip requirements files (e.g. requirements.txt,
    dev-requirements.txt, requirements-dev.txt).
    """

    # Fixed: the `.` before the extension was unescaped, so names like
    # "requirementsXtxt" would also match.
    __file_name_pattern = re.compile(r"^([a-zA-Z_-]+)?requirements([a-zA-Z_-]+)?\.txt$")

    def is_supported(self, file: Path) -> bool:
        """Return True when *file* looks like a pip requirements file."""
        return self.__file_name_pattern.match(os.path.basename(file)) is not None

    def configure(
        self, file: Path, org_slug: Optional[str], project_id: Optional[str]
    ) -> None:
        """Point the requirements file at the Safety index."""
        Pip.configure_requirements(file, org_slug, project_id)
class PoetryPyprojectConfigurator(BuildFileConfigurator):
    """
    Configures pyproject.toml files for Poetry-managed projects.
    """

    # Fixed: escape the `.` so e.g. "pyprojectXtoml" does not match.
    __file_name_pattern = re.compile(r"^pyproject\.toml$")

    def is_supported(self, file: Path) -> bool:
        """True only for pyproject.toml files that belong to a Poetry project."""
        return self.__file_name_pattern.match(
            os.path.basename(file)
        ) is not None and Poetry.is_poetry_project_file(file)

    def configure(
        self, file: Path, org_slug: Optional[str], project_id: Optional[str]
    ) -> Optional[Path]:
        """Point the pyproject.toml at the Safety index when supported."""
        if self.is_supported(file):
            return Poetry.configure_pyproject(file, org_slug, project_id)  # type: ignore
        return None
# TODO: Review if we should move this/hook up this into interceptors.
class ToolConfigurator(abc.ABC):
    """
    Interface for configurators that change a tool's system-wide
    (user-level) configuration.
    """

    @abc.abstractmethod
    def configure(self, org_slug: Optional[str]) -> Any:
        """
        Configures specific tool.

        Args:
            org_slug (Optional[str]): The organization slug.
        """
        pass

    @abc.abstractmethod
    def reset(self) -> None:
        """
        Resets specific tool.
        """
        pass
class PipConfigurator(ToolConfigurator):
    """System-level configurator for pip."""

    def configure(self, org_slug: Optional[str]) -> Optional[Path]:
        """Configure pip globally to use the Safety index."""
        return Pip.configure_system(org_slug)

    def reset(self) -> None:
        """Restore pip's previous global configuration."""
        Pip.reset_system()
class PoetryConfigurator(ToolConfigurator):
    """
    No-op configurator: configuring Poetry at the system level is not
    supported, so both operations intentionally do nothing.
    """

    def configure(self, org_slug: Optional[str]) -> Optional[Path]:
        return None

    def reset(self) -> None:
        return None
class UvConfigurator(ToolConfigurator):
    """System-level configurator for uv."""

    def configure(self, org_slug: Optional[str]) -> Optional[Path]:
        """Configure uv globally to use the Safety index."""
        return Uv.configure_system(org_slug)

    def reset(self) -> None:
        """Restore uv's previous global configuration."""
        Uv.reset_system()
class UvPyprojectConfigurator(BuildFileConfigurator):
    """
    Configures pyproject.toml for uv-managed projects, which are detected
    by the presence of a uv.lock file.
    """

    # Fixed: escape the `.` so e.g. "uvXlock" does not match.
    __file_name_pattern = re.compile(r"^uv\.lock$")

    def is_supported(self, file: Path) -> bool:
        # NOTE(review): "pyproject.toml" is resolved against the current
        # working directory, not file.parent — confirm callers always run
        # from the project root.
        return (
            self.__file_name_pattern.match(os.path.basename(file)) is not None
            and Path("pyproject.toml").exists()
        )

    def configure(
        self, file: Path, org_slug: Optional[str], project_id: Optional[str]
    ) -> Optional[Path]:
        """Configure the adjacent pyproject.toml when *file* is a uv.lock."""
        if self.is_supported(file):
            return Uv.configure_pyproject(Path("pyproject.toml"), org_slug, project_id)
        return None
class NpmConfigurator(ToolConfigurator):
    """System-level configurator for npm."""

    def configure(self, org_slug: Optional[str]) -> Optional[Path]:
        """Configure npm globally to use the Safety registry."""
        return Npm.configure_system(org_slug)

    def reset(self) -> None:
        """Restore npm's previous global configuration."""
        Npm.reset_system()
class NpmProjectConfigurator(BuildFileConfigurator):
    """
    Configures npm projects via their package.json file.
    """

    # Fixed: escape the `.` so e.g. "packageXjson" does not match.
    __file_name_pattern = re.compile(r"^package\.json$")

    def is_supported(self, file: Path) -> bool:
        """Return True when *file* is named package.json."""
        return self.__file_name_pattern.match(os.path.basename(file)) is not None

    def configure(
        self, file: Path, org_slug: Optional[str], project_id: Optional[str]
    ) -> Optional[Path]:
        """Point the npm project at the Safety registry when supported."""
        if self.is_supported(file):
            return Npm.configure_project(file, org_slug, project_id)
        return None

View File

@@ -0,0 +1,121 @@
from typing import List, Optional
from pathlib import Path
from .main import Uv
import typer
from safety.tool.auth import index_credentials
from ..base import BaseCommand
from ..environment_diff import EnvironmentDiffTracker, PipEnvironmentDiffTracker
from ..mixins import InstallationAuditMixin
from safety_schemas.models.events.types import ToolType
from safety.models import ToolResult
from .parser import UvParser
class UvCommand(BaseCommand):
    """
    Base command wrapper for running `uv` through Safety.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def get_command_name(self) -> List[str]:
        return ["uv"]

    def get_diff_tracker(self) -> "EnvironmentDiffTracker":
        # uv manages standard Python environments, so the pip tracker applies.
        return PipEnvironmentDiffTracker()

    def get_tool_type(self) -> ToolType:
        return ToolType.UV

    def get_package_list_command(self) -> List[str]:
        """
        Build the `uv run` command that lists installed packages.
        """
        # With --active, uv targets the active virtual environment instead
        # of the project-local one, so forward the flag when the user
        # passed it.
        active_flag: List[str] = []
        if self._intention and self._intention.options.get("active"):
            active_flag = ["--active"]
        list_pkgs_script = Path(__file__).parent / "list_pkgs.py"
        # --no-sync keeps uv from mutating the environment (creating the
        # venv or lock file) just to list packages.
        command = self.get_command_name() + ["run"]
        command += active_flag
        command += ["--no-sync", "python", str(list_pkgs_script)]
        return command

    @classmethod
    def from_args(cls, args: List[str], **kwargs):
        """
        Build the right command wrapper for *args*: auditable when the
        invocation can modify installed packages.
        """
        intention = UvParser().parse(args)
        if intention:
            kwargs["intention"] = intention
            if intention.modifies_packages():
                return AuditableUvCommand(args, **kwargs)
        return UvCommand(args, **kwargs)
class AuditableUvCommand(UvCommand, InstallationAuditMixin):
    """
    UvCommand variant that audits package installations and routes
    resolution through the Safety index.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Safety index URL stripped from the args in before(), re-injected
        # via the environment in env().
        self.__index_url = None

    def before(self, ctx: typer.Context):
        """
        Strip a Safety index passed via --index-url/-i from the args,
        remembering it so it can be re-injected through the environment.
        """
        super().before(ctx)
        args: List[Optional[str]] = self._args.copy()  # type: ignore
        if self._intention:
            if index_opt := self._intention.options.get(
                "index-url"
            ) or self._intention.options.get("i"):
                index_value = index_opt["value"]
                if index_value and index_value.startswith("https://pkgs.safetycli.com"):
                    self.__index_url = index_value
                    arg_index = index_opt["arg_index"]
                    value_index = index_opt["value_index"]
                    # Fixed: compare against None instead of truthiness so a
                    # valid index of 0 is not skipped.
                    if (
                        arg_index is not None
                        and value_index is not None
                        and arg_index < len(args)
                        and value_index < len(args)
                    ):
                        args[arg_index] = None
                        args[value_index] = None
        self._args = [arg for arg in args if arg is not None]

    def after(self, ctx: typer.Context, result: ToolResult):
        """Run the installation audit once the command has finished."""
        super().after(ctx, result)
        self.handle_installation_audit(ctx, result)

    def env(self, ctx: typer.Context) -> dict:
        """
        Environment for the wrapped uv process, with the Safety index
        injected as uv's default index.
        """
        env = super().env(ctx)
        default_index_url = Uv.build_index_url(ctx, self.__index_url)
        # uv config precedence:
        # 1. Command line args -> We rewrite the args if a default index is provided via command line args.
        # 2. Environment variables -> We set the default index to the Safety index
        # 3. Config files
        env.update(
            {
                # Default index URL
                # When the package manager is wrapped, we provide a default index so the search always falls back to the Safety index
                # UV_INDEX_URL is deprecated by UV; we leave it unset to avoid an annoying warning. UV_DEFAULT_INDEX is available since uv 0.4.23,
                # so we support only UV_DEFAULT_INDEX, as we don't inject the uv version in the command pipeline yet.
                #
                # "UV_INDEX_URL": default_index_url,
                #
                "UV_DEFAULT_INDEX": default_index_url,
                # Credentials for the named index in case of being set in the pyproject.toml
                "UV_INDEX_SAFETY_USERNAME": "user",
                "UV_INDEX_SAFETY_PASSWORD": index_credentials(ctx),
            }
        )
        return env

View File

@@ -0,0 +1,40 @@
import importlib.metadata as md
import json
import os
def get_package_location(dist):
    """
    Best-effort resolution of a distribution's install directory.

    Returns the absolute path of the distribution root, or an empty
    string when it cannot be determined.
    """
    locate = getattr(dist, "locate_file", None)
    try:
        if callable(locate):
            root = locate("")
            if root:
                return os.path.abspath(str(root))
    except (AttributeError, OSError, TypeError):
        pass
    return ""
def main() -> int:
    """
    List all installed packages with their versions and locations,
    printed as compact JSON on stdout. Returns 0 (process exit status).
    """
    inventory = [
        {
            "name": dist.metadata.get("Name", ""),
            "version": dist.version,
            "location": get_package_location(dist),
        }
        for dist in md.distributions()
    ]
    print(json.dumps(inventory, separators=(",", ":")))
    return 0
# Script entry point: emit the package inventory as JSON and exit with
# the status returned by main().
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,239 @@
import logging
import os
from pathlib import Path
import shutil
import sys
from typing import Any, Dict, Optional
import tomlkit
import typer
from rich.console import Console
from safety.console import main_console
from safety.tool.auth import build_index_url
from safety.tool.constants import (
PYPI_ORGANIZATION_REPOSITORY_URL,
PYPI_PUBLIC_REPOSITORY_URL,
PYPI_PROJECT_REPOSITORY_URL,
)
from safety.utils.pyapp_utils import get_path
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
logger = logging.getLogger(__name__)
def backup_file(path: Path) -> None:
    """
    Copy *path* to a sibling `<name>.backup` file, preserving metadata.

    No-op when *path* does not exist.
    """
    if not path.exists():
        return
    shutil.copy2(path, path.with_name(f"{path.name}.backup"))
class Uv:
    """
    Helpers for configuring the `uv` package manager to resolve packages
    through Safety's PyPI repositories (project-, org-, or public-level).
    """

    @classmethod
    def is_installed(cls) -> bool:
        """
        Checks if the UV program is installed

        Returns:
            True if UV is installed on system, or false otherwise
        """
        return shutil.which("uv", path=get_path()) is not None

    @classmethod
    def is_uv_project_file(cls, file: Path) -> bool:
        """
        Return True when *file* (a pyproject.toml) belongs to a uv project:
        it declares a [tool.uv] table or sits next to a uv.lock file.
        """
        try:
            cfg = tomllib.loads(file.read_text())
            return (
                cfg.get("tool", {}).get("uv") is not None
                or (file.parent / "uv.lock").exists()
            )
        except (IOError, ValueError):
            # Unreadable or invalid TOML -> not a uv project.
            return False

    @classmethod
    def configure_pyproject(
        cls,
        file: Path,
        org_slug: Optional[str],
        project_id: Optional[str] = None,
        console: Console = main_console,
    ) -> Optional[Path]:
        """
        Configures index url for specified pyproject.toml file.

        Args:
            file (Path): Path to pyproject.toml file.
            org_slug (Optional[str]): Organization slug.
            project_id (Optional[str]): Project ID.
            console (Console): Console instance.

        Returns:
            Optional[Path]: The configured file, or None on failure.
        """
        if not cls.is_installed():
            logger.error("UV is not installed.")
            return None
        # Most specific repository wins: project > organization > public.
        repository_url = (
            PYPI_PROJECT_REPOSITORY_URL.format(org_slug, project_id)
            if project_id and org_slug
            else (
                PYPI_ORGANIZATION_REPOSITORY_URL.format(org_slug)
                if org_slug
                else PYPI_PUBLIC_REPOSITORY_URL
            )
        )
        try:
            content = file.read_text()
            doc: Dict[str, Any] = tomlkit.loads(content)
            # Ensure the [tool.uv.index] array-of-tables exists.
            if "tool" not in doc:
                doc["tool"] = tomlkit.table()
            if "uv" not in doc["tool"]:  # type: ignore
                doc["tool"]["uv"] = tomlkit.table()  # type: ignore
            if "index" not in doc["tool"]["uv"]:  # type: ignore
                doc["tool"]["uv"]["index"] = tomlkit.aot()  # type: ignore
            index_container = doc["tool"]["uv"]  # type: ignore
            # Drop any stale Safety index entries before re-adding ours.
            cls.filter_out_safety_index(index_container)
            safety_index = {
                "name": "safety",
                "url": repository_url,
                # In UV default:
                # True = lowest priority
                # False = highest priority
                "default": False,
            }
            # Read back the remaining (already filtered) indexes.
            non_safety_indexes = (
                doc.get("tool", {}).get("uv", {}).get("index", tomlkit.aot())
            )
            # Add safety index as first priority
            index_container["index"] = tomlkit.aot()  # type: ignore
            index_container["index"].append(safety_index)  # type: ignore
            index_container["index"].extend(non_safety_indexes)  # type: ignore
            # Write back to file
            file.write_text(tomlkit.dumps(doc))
            return file
        except (IOError, ValueError, Exception) as e:
            # NOTE(review): `Exception` already subsumes IOError/ValueError;
            # tuple kept as-is to avoid code changes in a doc-only pass.
            logger.error(f"Failed to configure {file} file: {e}")
            return None

    @classmethod
    def get_user_config_path(cls) -> Path:
        """
        Returns the path to the user config file for UV.

        This logic is based on the uv documentation:
        https://docs.astral.sh/uv/configuration/files/

        "uv will also discover user-level configuration at
        ~/.config/uv/uv.toml (or $XDG_CONFIG_HOME/uv/uv.toml) on macOS and Linux,
        or %APPDATA%\\uv\\uv.toml on Windows; ..."

        Returns:
            Path: The path to the user config file.
        """
        if sys.platform == "win32":
            return Path(os.environ.get("APPDATA", ""), "uv", "uv.toml")
        else:
            xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
            if xdg_config_home:
                return Path(xdg_config_home, "uv", "uv.toml")
            else:
                return Path(Path.home(), ".config", "uv", "uv.toml")

    @classmethod
    def filter_out_safety_index(cls, index_container: Any):
        """
        Remove any Safety-hosted entries (URL containing ".safetycli.com")
        from the container's "index" array-of-tables, in place.
        """
        if "index" not in index_container:
            return
        indexes = list(index_container["index"])
        index_container["index"] = tomlkit.aot()
        for index in indexes:
            index_url = index.get("url", "")
            if ".safetycli.com" in index_url:
                continue
            index_container["index"].append(index)

    @classmethod
    def configure_system(
        cls, org_slug: Optional[str], console: Console = main_console
    ) -> Optional[Path]:
        """
        Configures UV system to use to Safety index url.

        Writes the user-level uv.toml, backing up any existing file first.

        Returns:
            Optional[Path]: The written config path, or None on failure.
        """
        try:
            repository_url = (
                PYPI_ORGANIZATION_REPOSITORY_URL.format(org_slug)
                if org_slug
                else PYPI_PUBLIC_REPOSITORY_URL
            )
            user_config_path = cls.get_user_config_path()
            if not user_config_path.exists():
                user_config_path.parent.mkdir(parents=True, exist_ok=True)
                content = ""
            else:
                # Keep a .backup copy before rewriting the user's config.
                backup_file(user_config_path)
                content = user_config_path.read_text()
            doc = tomlkit.loads(content)
            if "index" not in doc:
                doc["index"] = tomlkit.aot()
            cls.filter_out_safety_index(index_container=doc)
            safety_index = tomlkit.aot()
            safety_index.append(
                {
                    "name": "safety",
                    "url": repository_url,
                    # In UV default:
                    # True = lowest priority
                    # False = highest priority
                    "default": False,
                }
            )
            non_safety_indexes = doc.get("index", tomlkit.aot())
            # Add safety index as first priority
            doc["index"] = tomlkit.aot()
            # NOTE(review): relies on tomlkit merging AoTs appended under the
            # same "index" key — confirm against the pinned tomlkit version.
            doc.append("index", safety_index)
            doc.append("index", non_safety_indexes)
            user_config_path.write_text(tomlkit.dumps(doc))
            return user_config_path
        except Exception as e:
            logger.error(f"Failed to configure UV system: {e}")
            return None

    @classmethod
    def reset_system(cls, console: Console = main_console):
        """
        Remove the Safety index from the user-level uv.toml, backing the
        file up first. Failures are logged, never raised.
        """
        try:
            user_config_path = cls.get_user_config_path()
            if user_config_path.exists():
                backup_file(user_config_path)
                content = user_config_path.read_text()
                doc = tomlkit.loads(content)
                cls.filter_out_safety_index(index_container=doc)
                user_config_path.write_text(tomlkit.dumps(doc))
        except Exception as e:
            msg = "Failed to reset UV global settings"
            logger.error(f"{msg}: {e}")

    @classmethod
    def build_index_url(cls, ctx: typer.Context, index_url: Optional[str]) -> str:
        """
        Build the authenticated Safety PyPI index URL for this context.
        """
        return build_index_url(ctx, index_url, "pypi")

View File

@@ -0,0 +1,160 @@
from typing import Dict, Union, Set
from ..base import ToolCommandLineParser
from ..intents import ToolIntentionType
# The sets below list `uv` options that take NO value, so the parser can
# distinguish boolean flags from options that consume the next token.
# Derived from `uv --help` and the relevant subcommand help output.
UV_CACHE_FLAGS = {
    "no-cache",
    "n",
    "refresh",
}
UV_PYTHON_FLAGS = {
    "managed-python",
    "no-managed-python",
    "no-python-downloads",
}
UV_INDEX_FLAGS = {
    "no-index",
}
UV_RESOLVER_FLAGS = {
    "upgrade",
    "U",
    "no-sources",
}
UV_INSTALLER_FLAGS = {
    "reinstall",
    "compile-bytecode",
}
UV_BUILD_FLAGS = {
    "no-build-isolation",
    "no-build",
    "no-binary",
}
UV_GLOBAL_FLAGS = {
    "quiet",
    "q",
    "verbose",
    "v",
    "native-tls",
    "offline",
    "no-progress",
    "no-config",
    "help",
    "h",
    "version",
    "V",
}
UV_PIP_INSTALL_FLAGS = {
    "all-extras",
    "no-deps",
    "require-hashes",
    "no-verify-hashes",
    "system",
    "break-system-packages",
    "no-break-system-packages",
    "no-build",
    "exact",
    "strict",
    "dry-run",
    "user",
}
UV_PIP_UNINSTALL_FLAGS = {
    "system",
    "break-system-packages",
    "no-break-system-packages",
    "dry-run",
}
# Known value-less flags per command scope; "global" applies everywhere.
UV_KNOWN_FLAGS: Dict[str, Set[str]] = {
    "global": UV_GLOBAL_FLAGS
    | UV_CACHE_FLAGS
    | UV_PYTHON_FLAGS
    | UV_INDEX_FLAGS
    | UV_RESOLVER_FLAGS
    | UV_INSTALLER_FLAGS
    | UV_BUILD_FLAGS,
    # 2-level commands
    "add": {
        # From `uv add --help`
        "dev",
        "editable",
        "raw",
        "no-sync",
        "locked",
        "frozen",
        "active",
        "workspace",
        "no-workspace",
        "no-install-project",
        "no-install-workspace",
    },
    "remove": {
        "dev",
        "no-sync",
        "active",
        "locked",
        "frozen",
    },
    "sync": {
        "all-extras",
        "no-dev",
        "only-dev",
        "no-default-groups",
        "all-groups",
        "no-editable",
        "inexact",
        "active",
        "no-install-project",
        "no-install-workspace",
        "locked",
        "frozen",
        "dry-run",
        "all-packages",
        "check",
    },
    # 3-level pip commands
    "pip.install": UV_PIP_INSTALL_FLAGS,
    "pip.uninstall": UV_PIP_UNINSTALL_FLAGS,
}
class UvParser(ToolCommandLineParser):
    """
    Command-line parser describing the `uv` CLI for intention detection.
    """

    def get_tool_name(self) -> str:
        return "uv"

    def get_command_hierarchy(self) -> Dict[str, Union[ToolIntentionType, Dict]]:
        """
        Context for command hierarchy parsing
        """
        pip_subcommands = {
            "install": ToolIntentionType.ADD_PACKAGE,
            "uninstall": ToolIntentionType.REMOVE_PACKAGE,
            "download": ToolIntentionType.DOWNLOAD_PACKAGE,
            "list": ToolIntentionType.LIST_PACKAGES,
        }
        hierarchy: Dict[str, Union[ToolIntentionType, Dict]] = {
            # 2-level commands
            "add": ToolIntentionType.ADD_PACKAGE,
            "remove": ToolIntentionType.REMOVE_PACKAGE,
            "build": ToolIntentionType.BUILD_PROJECT,
            "sync": ToolIntentionType.SYNC_PACKAGES,
            # 3-level commands
            "pip": pip_subcommands,
        }
        return hierarchy

    def get_known_flags(self) -> Dict[str, Set[str]]:
        """
        Define flags that DON'T take values for uv.
        These were derived from `uv --help` and subcommand helps.
        """
        return UV_KNOWN_FLAGS