This commit is contained in:
Iliyan Angelov
2025-09-19 11:58:53 +03:00
parent 306b20e24a
commit 6b247e5b9f
11423 changed files with 1500615 additions and 778 deletions

View File

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
import sys
from django.core.management.base import BaseCommand
from logging import getLogger
logger = getLogger("django.commands")
class LoggingBaseCommand(BaseCommand):
"""
A subclass of BaseCommand that logs run time errors to `django.commands`.
To use this, create a management command subclassing LoggingBaseCommand:
from django_extensions.management.base import LoggingBaseCommand
class Command(LoggingBaseCommand):
help = 'Test error'
def handle(self, *args, **options):
raise Exception
And then define a logging handler in settings.py:
LOGGING = {
... # Other stuff here
'handlers': {
'mail_admins': {
'level': 'ERROR',
'filters': ['require_debug_false'],
'class': 'django.utils.log.AdminEmailHandler'
},
},
'loggers': {
'django.commands': {
'handlers': ['mail_admins'],
'level': 'ERROR',
'propagate': False,
},
}
}
"""
def execute(self, *args, **options):
try:
super().execute(*args, **options)
except Exception as e:
logger.error(e, exc_info=sys.exc_info(), extra={"status_code": 500})
raise

View File

@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
from django.core.management import color
from django.utils import termcolors
def _dummy_style_func(msg):
return msg
def no_style():
style = color.no_style()
for role in ("INFO", "WARN", "BOLD", "URL", "MODULE", "MODULE_NAME", "URL_NAME"):
setattr(style, role, _dummy_style_func)
return style
def color_style():
if color.supports_color():
style = color.color_style()
style.INFO = termcolors.make_style(fg="green")
style.WARN = termcolors.make_style(fg="yellow")
style.BOLD = termcolors.make_style(opts=("bold",))
style.URL = termcolors.make_style(fg="green", opts=("bold",))
style.MODULE = termcolors.make_style(fg="yellow")
style.MODULE_NAME = termcolors.make_style(opts=("bold",))
style.URL_NAME = termcolors.make_style(fg="red")
else:
style = no_style()
return style

View File

@@ -0,0 +1,395 @@
# -*- coding: utf-8 -*-
"""
The Django Admin Generator is a project which can automatically generate
(scaffold) a Django Admin for you. By doing this it will introspect your
models and automatically generate an Admin with properties like:
- `list_display` for all local fields
- `list_filter` for foreign keys with few items
- `raw_id_fields` for foreign keys with a lot of items
- `search_fields` for name and `slug` fields
- `prepopulated_fields` for `slug` fields
- `date_hierarchy` for `created_at`, `updated_at` or `joined_at` fields
The original source and latest version can be found here:
https://github.com/WoLpH/django-admin-generator/
"""
import re
from django.apps import apps
from django.conf import settings
from django.core.management.base import LabelCommand, CommandError
from django.db import models
from django_extensions.management.utils import signalcommand
# Configurable constants
MAX_LINE_WIDTH = getattr(settings, "MAX_LINE_WIDTH", 78)
INDENT_WIDTH = getattr(settings, "INDENT_WIDTH", 4)
LIST_FILTER_THRESHOLD = getattr(settings, "LIST_FILTER_THRESHOLD", 25)
RAW_ID_THRESHOLD = getattr(settings, "RAW_ID_THRESHOLD", 100)
LIST_FILTER = getattr(
settings,
"LIST_FILTER",
(
models.DateField,
models.DateTimeField,
models.ForeignKey,
models.BooleanField,
),
)
SEARCH_FIELD_NAMES = getattr(
settings,
"SEARCH_FIELD_NAMES",
(
"name",
"slug",
),
)
DATE_HIERARCHY_NAMES = getattr(
settings,
"DATE_HIERARCHY_NAMES",
(
"joined_at",
"updated_at",
"created_at",
),
)
PREPOPULATED_FIELD_NAMES = getattr(settings, "PREPOPULATED_FIELD_NAMES", ("slug=name",))
PRINT_IMPORTS = getattr(
settings,
"PRINT_IMPORTS",
"""# -*- coding: utf-8 -*-
from django.contrib import admin
from .models import %(models)s
""",
)
PRINT_ADMIN_CLASS = getattr(
settings,
"PRINT_ADMIN_CLASS",
"""
@admin.register(%(name)s)
class %(name)sAdmin(admin.ModelAdmin):%(class_)s
""",
)
PRINT_ADMIN_PROPERTY = getattr(
settings,
"PRINT_ADMIN_PROPERTY",
"""
%(key)s = %(value)s""",
)
class UnicodeMixin:
"""
Mixin class to handle defining the proper __str__/__unicode__
methods in Python 2 or 3.
"""
def __str__(self):
return self.__unicode__()
class AdminApp(UnicodeMixin):
def __init__(self, app_config, model_res, **options):
self.app_config = app_config
self.model_res = model_res
self.options = options
def __iter__(self):
for model in self.app_config.get_models():
admin_model = AdminModel(model, **self.options)
for model_re in self.model_res:
if model_re.search(admin_model.name):
break
else:
if self.model_res:
continue
yield admin_model
def __unicode__(self):
return "".join(self._unicode_generator())
def _unicode_generator(self):
models_list = [admin_model.name for admin_model in self]
yield PRINT_IMPORTS % dict(models=", ".join(models_list))
admin_model_names = []
for admin_model in self:
yield PRINT_ADMIN_CLASS % dict(
name=admin_model.name,
class_=admin_model,
)
admin_model_names.append(admin_model.name)
def __repr__(self):
return "<%s[%s]>" % (
self.__class__.__name__,
self.app.name,
)
class AdminModel(UnicodeMixin):
PRINTABLE_PROPERTIES = (
"list_display",
"list_filter",
"raw_id_fields",
"search_fields",
"prepopulated_fields",
"date_hierarchy",
)
def __init__(
self,
model,
raw_id_threshold=RAW_ID_THRESHOLD,
list_filter_threshold=LIST_FILTER_THRESHOLD,
search_field_names=SEARCH_FIELD_NAMES,
date_hierarchy_names=DATE_HIERARCHY_NAMES,
prepopulated_field_names=PREPOPULATED_FIELD_NAMES,
**options,
):
self.model = model
self.list_display = []
self.list_filter = []
self.raw_id_fields = []
self.search_fields = []
self.prepopulated_fields = {}
self.date_hierarchy = None
self.search_field_names = search_field_names
self.raw_id_threshold = raw_id_threshold
self.list_filter_threshold = list_filter_threshold
self.date_hierarchy_names = date_hierarchy_names
self.prepopulated_field_names = prepopulated_field_names
def __repr__(self):
return "<%s[%s]>" % (
self.__class__.__name__,
self.name,
)
@property
def name(self):
return self.model.__name__
def _process_many_to_many(self, meta):
raw_id_threshold = self.raw_id_threshold
for field in meta.local_many_to_many:
if hasattr(field, "remote_field"):
related_model = getattr(
field.remote_field, "related_model", field.remote_field.model
)
else:
raise CommandError("Unable to process ManyToMany relation")
related_objects = related_model.objects.all()
if related_objects[:raw_id_threshold].count() < raw_id_threshold:
yield field.name
def _process_fields(self, meta):
parent_fields = meta.parents.values()
for field in meta.fields:
name = self._process_field(field, parent_fields)
if name:
yield name
def _process_foreign_key(self, field):
raw_id_threshold = self.raw_id_threshold
list_filter_threshold = self.list_filter_threshold
max_count = max(list_filter_threshold, raw_id_threshold)
if hasattr(field, "remote_field"):
related_model = getattr(
field.remote_field, "related_model", field.remote_field.model
)
else:
raise CommandError("Unable to process ForeignKey relation")
related_count = related_model.objects.all()
related_count = related_count[:max_count].count()
if related_count >= raw_id_threshold:
self.raw_id_fields.append(field.name)
elif related_count < list_filter_threshold:
self.list_filter.append(field.name)
else: # pragma: no cover
pass # Do nothing :)
def _process_field(self, field, parent_fields):
if field in parent_fields:
return
field_name = str(field.name)
self.list_display.append(field_name)
if isinstance(field, LIST_FILTER):
if isinstance(field, models.ForeignKey):
self._process_foreign_key(field)
else:
self.list_filter.append(field_name)
if field.name in self.search_field_names:
self.search_fields.append(field_name)
return field_name
def __unicode__(self):
return "".join(self._unicode_generator())
def _yield_value(self, key, value):
if isinstance(value, (list, set, tuple)):
return self._yield_tuple(key, tuple(value))
elif isinstance(value, dict):
return self._yield_dict(key, value)
elif isinstance(value, str):
return self._yield_string(key, value)
else: # pragma: no cover
raise TypeError("%s is not supported in %r" % (type(value), value))
def _yield_string(self, key, value, converter=repr):
return PRINT_ADMIN_PROPERTY % dict(
key=key,
value=converter(value),
)
def _yield_dict(self, key, value):
row_parts = []
row = self._yield_string(key, value)
if len(row) > MAX_LINE_WIDTH:
row_parts.append(self._yield_string(key, "{", str))
for k, v in value.items():
row_parts.append("%s%r: %r" % (2 * INDENT_WIDTH * " ", k, v))
row_parts.append(INDENT_WIDTH * " " + "}")
row = "\n".join(row_parts)
return row
def _yield_tuple(self, key, value):
row_parts = []
row = self._yield_string(key, value)
if len(row) > MAX_LINE_WIDTH:
row_parts.append(self._yield_string(key, "(", str))
for v in value:
row_parts.append(2 * INDENT_WIDTH * " " + repr(v) + ",")
row_parts.append(INDENT_WIDTH * " " + ")")
row = "\n".join(row_parts)
return row
def _unicode_generator(self):
self._process()
for key in self.PRINTABLE_PROPERTIES:
value = getattr(self, key)
if value:
yield self._yield_value(key, value)
def _process(self):
meta = self.model._meta
self.raw_id_fields += list(self._process_many_to_many(meta))
field_names = list(self._process_fields(meta))
for field_name in self.date_hierarchy_names[::-1]:
if field_name in field_names and not self.date_hierarchy:
self.date_hierarchy = field_name
break
for k in sorted(self.prepopulated_field_names):
k, vs = k.split("=", 1)
vs = vs.split(",")
if k in field_names:
incomplete = False
for v in vs:
if v not in field_names:
incomplete = True
break
if not incomplete:
self.prepopulated_fields[k] = vs
self.processed = True
class Command(LabelCommand):
help = """Generate a `admin.py` file for the given app (models)"""
# args = "[app_name]"
can_import_settings = True
def add_arguments(self, parser):
parser.add_argument("app_name")
parser.add_argument("model_name", nargs="*")
parser.add_argument(
"-s",
"--search-field",
action="append",
default=SEARCH_FIELD_NAMES,
help="Fields named like this will be added to `search_fields`"
" [default: %(default)s]",
)
parser.add_argument(
"-d",
"--date-hierarchy",
action="append",
default=DATE_HIERARCHY_NAMES,
help="A field named like this will be set as `date_hierarchy`"
" [default: %(default)s]",
)
parser.add_argument(
"-p",
"--prepopulated-fields",
action="append",
default=PREPOPULATED_FIELD_NAMES,
help="These fields will be prepopulated by the other field."
"The field names can be specified like `spam=eggA,eggB,eggC`"
" [default: %(default)s]",
)
parser.add_argument(
"-l",
"--list-filter-threshold",
type=int,
default=LIST_FILTER_THRESHOLD,
metavar="LIST_FILTER_THRESHOLD",
help="If a foreign key has less than LIST_FILTER_THRESHOLD items "
"it will be added to `list_filter` [default: %(default)s]",
)
parser.add_argument(
"-r",
"--raw-id-threshold",
type=int,
default=RAW_ID_THRESHOLD,
metavar="RAW_ID_THRESHOLD",
help="If a foreign key has more than RAW_ID_THRESHOLD items "
"it will be added to `list_filter` [default: %(default)s]",
)
@signalcommand
def handle(self, *args, **options):
app_name = options["app_name"]
try:
app = apps.get_app_config(app_name)
except LookupError:
self.stderr.write("This command requires an existing app name as argument")
self.stderr.write("Available apps:")
app_labels = [app.label for app in apps.get_app_configs()]
for label in sorted(app_labels):
self.stderr.write(" %s" % label)
return
model_res = []
for arg in options["model_name"]:
model_res.append(re.compile(arg, re.IGNORECASE))
self.stdout.write(AdminApp(app, model_res, **options).__str__())

View File

@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
import fnmatch
import os
from os.path import join as _j
from typing import List
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
help = "Removes all python bytecode compiled files from the project."
requires_system_checks: List[str] = []
def add_arguments(self, parser):
parser.add_argument(
"--optimize",
"-o",
"-O",
action="store_true",
dest="optimize",
default=False,
help="Remove optimized python bytecode files",
)
parser.add_argument(
"--path",
"-p",
action="store",
dest="path",
help="Specify path to recurse into",
)
@signalcommand
def handle(self, *args, **options):
project_root = options.get("path", getattr(settings, "BASE_DIR", None))
if not project_root:
project_root = getattr(settings, "BASE_DIR", None)
verbosity = options["verbosity"]
if not project_root:
raise CommandError(
"No --path specified and settings.py does not contain BASE_DIR"
)
exts = options["optimize"] and "*.py[co]" or "*.pyc"
for root, dirs, filenames in os.walk(project_root):
for filename in fnmatch.filter(filenames, exts):
full_path = _j(root, filename)
if verbosity > 1:
self.stdout.write("%s\n" % full_path)
os.remove(full_path)

View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Author: AxiaCore S.A.S. https://axiacore.com
from django.conf import settings
from django.core.cache import DEFAULT_CACHE_ALIAS, caches
from django.core.cache.backends.base import InvalidCacheBackendError
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
"""A simple management command which clears the site-wide cache."""
help = "Fully clear site-wide cache."
def add_arguments(self, parser):
parser.add_argument("--cache", action="append", help="Name of cache to clear")
parser.add_argument(
"--all",
"-a",
action="store_true",
default=False,
dest="all_caches",
help="Clear all configured caches",
)
@signalcommand
def handle(self, cache, all_caches, *args, **kwargs):
if not cache and not all_caches:
cache = [DEFAULT_CACHE_ALIAS]
elif cache and all_caches:
raise CommandError("Using both --all and --cache is not supported")
elif all_caches:
cache = getattr(settings, "CACHES", {DEFAULT_CACHE_ALIAS: {}}).keys()
for key in cache:
try:
caches[key].clear()
except InvalidCacheBackendError:
self.stderr.write('Cache "%s" is invalid!\n' % key)
else:
self.stdout.write('Cache "%s" has been cleared!\n' % key)

View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
import fnmatch
import os
import py_compile
from os.path import join as _j
from typing import List
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
help = "Compile python bytecode files for the project."
requires_system_checks: List[str] = []
def add_arguments(self, parser):
parser.add_argument(
"--path",
"-p",
action="store",
dest="path",
help="Specify path to recurse into",
)
@signalcommand
def handle(self, *args, **options):
project_root = options["path"]
if not project_root:
project_root = getattr(settings, "BASE_DIR", None)
verbosity = options["verbosity"]
if not project_root:
raise CommandError(
"No --path specified and settings.py does not contain BASE_DIR"
)
for root, dirs, filenames in os.walk(project_root):
for filename in fnmatch.filter(filenames, "*.py"):
full_path = _j(root, filename)
if verbosity > 1:
self.stdout.write("Compiling %s...\n" % full_path)
py_compile.compile(full_path)

View File

@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
import os
import sys
import shutil
from typing import List
from django.core.management.base import AppCommand
from django.core.management.color import color_style
from django_extensions.management.utils import _make_writeable, signalcommand
class Command(AppCommand):
help = "Creates a Django management command directory structure for the given app "
"name in the app's directory."
requires_system_checks: List[str] = []
# Can't import settings during this command, because they haven't
# necessarily been created.
can_import_settings = True
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--name",
"-n",
action="store",
dest="command_name",
default="sample",
help="The name to use for the management command",
)
parser.add_argument(
"--base",
"-b",
action="store",
dest="base_command",
default="Base",
help="The base class used for implementation of "
"this command. Should be one of Base, App, Label, or NoArgs",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="Do not actually create any files",
)
@signalcommand
def handle_app_config(self, args, **options):
app = args
copy_template("command_template", app.path, **options)
def copy_template(template_name, copy_to, **options):
"""Copy the specified template directory to the copy_to location"""
import django_extensions
style = color_style()
ERROR = getattr(style, "ERROR", lambda x: x)
SUCCESS = getattr(style, "SUCCESS", lambda x: x)
command_name, base_command = (
options["command_name"],
"%sCommand" % options["base_command"],
)
dry_run = options["dry_run"]
verbosity = options["verbosity"]
template_dir = os.path.join(django_extensions.__path__[0], "conf", template_name)
# walk the template structure and copies it
for d, subdirs, files in os.walk(template_dir):
relative_dir = d[len(template_dir) + 1 :]
if relative_dir and not os.path.exists(os.path.join(copy_to, relative_dir)):
if not dry_run:
os.mkdir(os.path.join(copy_to, relative_dir))
for i, subdir in enumerate(subdirs):
if subdir.startswith("."):
del subdirs[i]
for f in files:
if f.endswith((".pyc", ".pyo")) or f.startswith(
(".DS_Store", "__pycache__")
):
continue
path_old = os.path.join(d, f)
path_new = os.path.join(
copy_to, relative_dir, f.replace("sample", command_name)
).rstrip(".tmpl")
if os.path.exists(path_new):
path_new = os.path.join(copy_to, relative_dir, f).rstrip(".tmpl")
if os.path.exists(path_new):
if verbosity > 1:
print(ERROR("%s already exists" % path_new))
continue
if verbosity > 1:
print(SUCCESS("%s" % path_new))
with open(path_old, "r") as fp_orig:
data = fp_orig.read()
data = data.replace("{{ command_name }}", command_name)
data = data.replace("{{ base_command }}", base_command)
if not dry_run:
with open(path_new, "w") as fp_new:
fp_new.write(data)
if not dry_run:
try:
shutil.copymode(path_old, path_new)
_make_writeable(path_new)
except OSError:
sys.stderr.write(
"Notice: Couldn't set permission bits on %s. You're probably using an uncommon filesystem setup. No problem.\n" # noqa: E501
% path_new
)

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
import os
import sys
import shutil
from typing import List
from django.core.management.base import AppCommand
from django.core.management.color import color_style
from django_extensions.management.utils import _make_writeable, signalcommand
class Command(AppCommand):
help = "Creates a Django jobs command directory structure for the given app name "
"in the current directory."
requires_system_checks: List[str] = []
# Can't import settings during this command, because they haven't
# necessarily been created.
can_import_settings = True
@signalcommand
def handle_app_config(self, app, **options):
copy_template("jobs_template", app.path, **options)
def copy_template(template_name, copy_to, **options):
"""Copy the specified template directory to the copy_to location"""
import django_extensions
style = color_style()
ERROR = getattr(style, "ERROR", lambda x: x)
SUCCESS = getattr(style, "SUCCESS", lambda x: x)
template_dir = os.path.join(django_extensions.__path__[0], "conf", template_name)
verbosity = options["verbosity"]
# walks the template structure and copies it
for d, subdirs, files in os.walk(template_dir):
relative_dir = d[len(template_dir) + 1 :]
if relative_dir and not os.path.exists(os.path.join(copy_to, relative_dir)):
os.mkdir(os.path.join(copy_to, relative_dir))
for i, subdir in enumerate(subdirs):
if subdir.startswith("."):
del subdirs[i]
for f in files:
if f.endswith(".pyc") or f.startswith(".DS_Store"):
continue
path_old = os.path.join(d, f)
path_new = os.path.join(copy_to, relative_dir, f).rstrip(".tmpl")
if os.path.exists(path_new):
if verbosity > 1:
print(ERROR("%s already exists" % path_new))
continue
if verbosity > 1:
print(SUCCESS("%s" % path_new))
with open(path_old, "r") as fp_orig:
with open(path_new, "w") as fp_new:
fp_new.write(fp_orig.read())
try:
shutil.copymode(path_old, path_new)
_make_writeable(path_new)
except OSError:
sys.stderr.write(
"Notice: Couldn't set permission bits on %s. You're probably using an uncommon filesystem setup. No problem.\n" # noqa: E501
% path_new
)

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
import os
import sys
from typing import List
from django.core.management.base import AppCommand
from django_extensions.management.utils import _make_writeable, signalcommand
class Command(AppCommand):
help = "Creates a Django template tags directory structure for the given app name "
"in the apps's directory"
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--name",
"-n",
action="store",
dest="tag_library_name",
default="appname_tags",
help="The name to use for the template tag base name. "
"Defaults to `appname`_tags.",
)
requires_system_checks: List[str] = []
# Can't import settings during this command, because they haven't
# necessarily been created.
can_import_settings = True
@signalcommand
def handle_app_config(self, app_config, **options):
app_dir = app_config.path
tag_library_name = options["tag_library_name"]
if tag_library_name == "appname_tags":
tag_library_name = "%s_tags" % os.path.basename(app_dir)
copy_template("template_tags_template", app_dir, tag_library_name)
def copy_template(template_name, copy_to, tag_library_name):
"""Copy the specified template directory to the copy_to location"""
import django_extensions
import shutil
template_dir = os.path.join(django_extensions.__path__[0], "conf", template_name)
# walk the template structure and copies it
for d, subdirs, files in os.walk(template_dir):
relative_dir = d[len(template_dir) + 1 :]
if relative_dir and not os.path.exists(os.path.join(copy_to, relative_dir)):
os.mkdir(os.path.join(copy_to, relative_dir))
for i, subdir in enumerate(subdirs):
if subdir.startswith("."):
del subdirs[i]
for f in files:
if f.endswith(".pyc") or f.startswith(".DS_Store"):
continue
path_old = os.path.join(d, f)
path_new = os.path.join(
copy_to, relative_dir, f.replace("sample", tag_library_name)
)
if os.path.exists(path_new):
path_new = os.path.join(copy_to, relative_dir, f)
if os.path.exists(path_new):
continue
path_new = path_new.rstrip(".tmpl")
fp_old = open(path_old, "r")
fp_new = open(path_new, "w")
fp_new.write(fp_old.read())
fp_old.close()
fp_new.close()
try:
shutil.copymode(path_old, path_new)
_make_writeable(path_new)
except OSError:
sys.stderr.write(
"Notice: Couldn't set permission bits on %s. You're probably using an uncommon filesystem setup. No problem.\n" # noqa: E501
% path_new
)

View File

@@ -0,0 +1,214 @@
# -*- coding: utf-8 -*-
import os
import inspect
import re
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.migrations.loader import AmbiguityError, MigrationLoader
REPLACES_REGEX = re.compile(r"\s+replaces\s*=\s*\[[^\]]+\]\s*")
PYC = ".pyc"
def py_from_pyc(pyc_fn):
return pyc_fn[: -len(PYC)] + ".py"
class Command(BaseCommand):
help = (
"Deletes left over migrations that have been replaced by a "
"squashed migration and converts squashed migration into a normal "
"migration. Modifies your source tree! Use with care!"
)
def add_arguments(self, parser):
parser.add_argument(
"app_label",
help="App label of the application to delete replaced migrations from.",
)
parser.add_argument(
"squashed_migration_name",
default=None,
nargs="?",
help="The squashed migration to replace. "
"If not specified defaults to the first found.",
)
parser.add_argument(
"--noinput",
"--no-input",
action="store_false",
dest="interactive",
default=True,
help="Tells Django to NOT prompt the user for input of any kind.",
)
parser.add_argument(
"--dry-run",
action="store_true",
default=False,
help="Do not actually delete or change any files",
)
parser.add_argument(
"--database",
default=DEFAULT_DB_ALIAS,
help=(
"Nominates a database to run command for. "
'Defaults to the "%s" database.'
)
% DEFAULT_DB_ALIAS,
)
def handle(self, **options):
self.verbosity = options["verbosity"]
self.interactive = options["interactive"]
self.dry_run = options["dry_run"]
app_label = options["app_label"]
squashed_migration_name = options["squashed_migration_name"]
database = options["database"]
# Load the current graph state
# check the app and migration they asked for exists
loader = MigrationLoader(connections[database])
if app_label not in loader.migrated_apps:
raise CommandError(
"App '%s' does not have migrations (so delete_squashed_migrations on "
"it makes no sense)" % app_label
)
squashed_migration = None
if squashed_migration_name:
squashed_migration = self.find_migration(
loader, app_label, squashed_migration_name
)
if not squashed_migration.replaces:
raise CommandError(
"The migration %s %s is not a squashed migration."
% (squashed_migration.app_label, squashed_migration.name)
)
else:
leaf_nodes = loader.graph.leaf_nodes(app=app_label)
migration = loader.get_migration(*leaf_nodes[0])
previous_migrations = [
loader.get_migration(al, mn)
for al, mn in loader.graph.forwards_plan(
(migration.app_label, migration.name)
)
if al == migration.app_label
]
migrations = previous_migrations + [migration]
for migration in migrations:
if migration.replaces:
squashed_migration = migration
break
if not squashed_migration:
raise CommandError(
"Cannot find a squashed migration in app '%s'." % (app_label)
)
files_to_delete = []
for al, mn in squashed_migration.replaces:
try:
migration = loader.disk_migrations[al, mn]
except KeyError:
if self.verbosity > 0:
self.stderr.write(
"Couldn't find migration file for %s %s\n" % (al, mn)
)
else:
pyc_file = inspect.getfile(migration.__class__)
files_to_delete.append(pyc_file)
if pyc_file.endswith(PYC):
py_file = py_from_pyc(pyc_file)
files_to_delete.append(py_file)
# Tell them what we're doing and optionally ask if we should proceed
if self.verbosity > 0 or self.interactive:
self.stdout.write(
self.style.MIGRATE_HEADING("Will delete the following files:")
)
for fn in files_to_delete:
self.stdout.write(" - %s" % fn)
if not self.confirm():
return
for fn in files_to_delete:
try:
if not self.dry_run:
os.remove(fn)
except OSError:
if self.verbosity > 0:
self.stderr.write("Couldn't delete %s\n" % (fn,))
# Try and delete replaces only if it's all on one line
squashed_migration_fn = inspect.getfile(squashed_migration.__class__)
if squashed_migration_fn.endswith(PYC):
squashed_migration_fn = py_from_pyc(squashed_migration_fn)
with open(squashed_migration_fn) as fp:
squashed_migration_lines = list(fp)
delete_lines = []
for i, line in enumerate(squashed_migration_lines):
if REPLACES_REGEX.match(line):
delete_lines.append(i)
if i > 0 and squashed_migration_lines[i - 1].strip() == "":
delete_lines.insert(0, i - 1)
break
if not delete_lines:
raise CommandError(
(
"Couldn't find 'replaces =' line in file %s. "
"Please finish cleaning up manually."
)
% (squashed_migration_fn,)
)
if self.verbosity > 0 or self.interactive:
self.stdout.write(
self.style.MIGRATE_HEADING(
"Will delete line %s%s from file %s"
% (
delete_lines[0],
" and " + str(delete_lines[1]) if len(delete_lines) > 1 else "",
squashed_migration_fn,
)
)
)
if not self.confirm():
return
for line_num in sorted(delete_lines, reverse=True):
del squashed_migration_lines[line_num]
with open(squashed_migration_fn, "w") as fp:
if not self.dry_run:
fp.write("".join(squashed_migration_lines))
def confirm(self):
if self.interactive:
answer = None
while not answer or answer not in "yn":
answer = input("Do you wish to proceed? [yN] ")
if not answer:
answer = "n"
break
else:
answer = answer[0].lower()
return answer == "y"
return True
def find_migration(self, loader, app_label, name):
try:
return loader.get_migration_by_prefix(app_label, name)
except AmbiguityError:
raise CommandError(
"More than one migration matches '%s' in app '%s'. Please be "
"more specific." % (name, app_label)
)
except KeyError:
raise CommandError(
"Cannot find a migration matching '%s' from app '%s'."
% (name, app_label)
)

View File

@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
from django.apps import apps
from django.core.management.base import CommandError, LabelCommand
from django.utils.encoding import force_str
from django_extensions.management.utils import signalcommand
class Command(LabelCommand):
help = "Outputs the specified model as a form definition to the shell."
def add_arguments(self, parser):
parser.add_argument("label", type=str, help="application name and model name")
parser.add_argument(
"--fields",
"-f",
action="append",
dest="fields",
default=[],
help="Describe form with these fields only",
)
@signalcommand
def handle(self, *args, **options):
label = options["label"]
fields = options["fields"]
return describe_form(label, fields)
def describe_form(label, fields):
"""Return a string describing a form based on the model"""
try:
app_name, model_name = label.split(".")[-2:]
except (IndexError, ValueError):
raise CommandError("Need application and model name in the form: appname.model")
model = apps.get_model(app_name, model_name)
opts = model._meta
field_list = []
for f in opts.fields + opts.many_to_many:
if not f.editable:
continue
if fields and f.name not in fields:
continue
formfield = f.formfield()
if "__dict__" not in dir(formfield):
continue
attrs = {}
valid_fields = [
"required",
"initial",
"max_length",
"min_length",
"max_value",
"min_value",
"max_digits",
"decimal_places",
"choices",
"help_text",
"label",
]
for k, v in formfield.__dict__.items():
if k in valid_fields and v is not None:
# ignore defaults, to minimize verbosity
if k == "required" and v:
continue
if k == "help_text" and not v:
continue
if k == "widget":
attrs[k] = v.__class__
elif k in ["help_text", "label"]:
attrs[k] = str(force_str(v).strip())
else:
attrs[k] = v
params = ", ".join(["%s=%r" % (k, v) for k, v in sorted(attrs.items())])
field_list.append(
" %(field_name)s = forms.%(field_type)s(%(params)s)"
% {
"field_name": f.name,
"field_type": formfield.__class__.__name__,
"params": params,
}
)
return """
from django import forms
from %(app_name)s.models import %(object_name)s
class %(object_name)sForm(forms.Form):
%(field_list)s
""" % {
"app_name": app_name,
"object_name": opts.object_name,
"field_list": "\n".join(field_list),
}

View File

@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
import importlib.util
from itertools import count
import os
import logging
import warnings
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS
from django.db.backends.base.creation import TEST_DATABASE_PREFIX
from django_extensions.settings import SQLITE_ENGINES, POSTGRESQL_ENGINES, MYSQL_ENGINES
from django_extensions.management.mysql import parse_mysql_cnf
from django_extensions.management.utils import signalcommand
from django_extensions.utils.deprecation import RemovedInNextVersionWarning
class Command(BaseCommand):
help = "Drops test database for this project."
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--noinput",
"--no-input",
action="store_false",
dest="interactive",
default=True,
help="Tells Django to NOT prompt the user for input of any kind.",
)
parser.add_argument(
"-U",
"--user",
action="store",
dest="user",
default=None,
help="Use another user for the database then defined in settings.py",
)
parser.add_argument(
"-P",
"--password",
action="store",
dest="password",
default=None,
help="Use another password for the database then defined in settings.py",
)
parser.add_argument(
"-D",
"--dbname",
action="store",
dest="dbname",
default=None,
help="Use another database name then defined in settings.py",
)
parser.add_argument(
"-R",
"--router",
action="store",
dest="router",
default=DEFAULT_DB_ALIAS,
help="Use this router-database other then defined in settings.py",
)
parser.add_argument(
"--database",
default=DEFAULT_DB_ALIAS,
help=(
"Nominates a database to run command for. "
'Defaults to the "%s" database.'
)
% DEFAULT_DB_ALIAS,
)
@signalcommand
def handle(self, *args, **options):
"""Drop test database for this project."""
database = options["database"]
if options["router"] != DEFAULT_DB_ALIAS:
warnings.warn(
"--router is deprecated. You should use --database.",
RemovedInNextVersionWarning,
stacklevel=2,
)
database = options["router"]
dbinfo = settings.DATABASES.get(database)
if dbinfo is None:
raise CommandError("Unknown database %s" % database)
engine = dbinfo.get("ENGINE")
user = password = database_name = database_host = database_port = ""
if engine == "mysql":
(user, password, database_name, database_host, database_port) = (
parse_mysql_cnf(dbinfo)
)
user = options["user"] or dbinfo.get("USER") or user
password = options["password"] or dbinfo.get("PASSWORD") or password
try:
database_name = dbinfo["TEST"]["NAME"]
except KeyError:
database_name = None
if database_name is None:
database_name = TEST_DATABASE_PREFIX + (
options["dbname"] or dbinfo.get("NAME")
)
if database_name is None or database_name == "":
raise CommandError(
"You need to specify DATABASE_NAME in your Django settings file."
)
database_host = dbinfo.get("HOST") or database_host
database_port = dbinfo.get("PORT") or database_port
verbosity = options["verbosity"]
if options["interactive"]:
confirm = input(
"""
You have requested to drop all test databases.
This will IRREVERSIBLY DESTROY
ALL data in the database "{db_name}"
and all cloned test databases generated via
the "--parallel" flag (these are sequentially
named "{db_name}_1", "{db_name}_2", etc.).
Are you sure you want to do this?
Type 'yes' to continue, or 'no' to cancel: """.format(db_name=database_name)
)
else:
confirm = "yes"
if confirm != "yes":
print("Reset cancelled.")
return
def get_database_names(formatter):
"""
Return a generator of all possible test database names.
e.g., 'test_foo', 'test_foo_1', test_foo_2', etc.
formatter: func returning a clone db name given the primary db name
and the clone's number, e.g., 'test_foo_1' for mysql/postgres, and
'test_foo_1..sqlite3' for sqlite (re: double dots, see comments).
"""
yield database_name
yield from (formatter(database_name, n) for n in count(1))
if engine in SQLITE_ENGINES:
# By default all sqlite test databases are created in memory.
# There will only be database files to delete if the developer has
# specified a test database name, which forces files to be written
# to disk.
logging.info("Unlinking %s databases" % engine)
def format_filename(name, number):
filename, ext = os.path.splitext(name)
# Since splitext() includes the dot in 'ext', the inclusion of
# the dot in the format string below is incorrect and creates a
# double dot. Django makes this mistake, so it must be
# replicated here. If fixed in Django, this code should be
# updated accordingly.
# Reference: https://code.djangoproject.com/ticket/32582
return "{}_{}.{}".format(filename, number, ext)
try:
for db_name in get_database_names(format_filename):
if not os.path.isfile(db_name):
break
logging.info('Unlinking database named "%s"' % db_name)
os.unlink(db_name)
except OSError:
return
elif engine in MYSQL_ENGINES:
import MySQLdb as Database
kwargs = {
"user": user,
"passwd": password,
}
if database_host.startswith("/"):
kwargs["unix_socket"] = database_host
else:
kwargs["host"] = database_host
if database_port:
kwargs["port"] = int(database_port)
connection = Database.connect(**kwargs)
cursor = connection.cursor()
for db_name in get_database_names("{}_{}".format):
exists_query = (
"SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
"WHERE SCHEMA_NAME='%s';" % db_name
)
row_count = cursor.execute(exists_query)
if row_count < 1:
break
drop_query = "DROP DATABASE IF EXISTS `%s`" % db_name
logging.info('Executing: "' + drop_query + '"')
cursor.execute(drop_query)
elif engine in POSTGRESQL_ENGINES:
has_psycopg3 = importlib.util.find_spec("psycopg")
if has_psycopg3:
import psycopg as Database # NOQA
else:
import psycopg2 as Database # NOQA
conn_params = {"dbname": "template1"}
if user:
conn_params["user"] = user
if password:
conn_params["password"] = password
if database_host:
conn_params["host"] = database_host
if database_port:
conn_params["port"] = database_port
connection = Database.connect(**conn_params)
if has_psycopg3:
connection.autocommit = True
else:
connection.set_isolation_level(0) # autocommit false
cursor = connection.cursor()
for db_name in get_database_names("{}_{}".format):
exists_query = (
"SELECT datname FROM pg_catalog.pg_database WHERE datname='%s';"
% db_name
)
try:
cursor.execute(exists_query)
# NOTE: Unlike MySQLdb, the psycopg2 cursor does not return the row
# count however both cursors provide it as a property
if cursor.rowcount < 1:
break
drop_query = 'DROP DATABASE IF EXISTS "%s";' % db_name
logging.info('Executing: "' + drop_query + '"')
cursor.execute(drop_query)
except Database.ProgrammingError as e:
logging.exception("Error: %s" % str(e))
return
else:
raise CommandError("Unknown database engine %s" % engine)
if verbosity >= 2 or options["interactive"]:
print("Reset successful.")

View File

@@ -0,0 +1,855 @@
# -*- coding: utf-8 -*-
"""
Title: Dumpscript management command
Project: Hardytools (queryset-refactor version)
Author: Will Hardy
Date: June 2008
Usage: python manage.py dumpscript appname > scripts/scriptname.py
$Revision: 217 $
Description:
Generates a Python script that will repopulate the database using objects.
The advantage of this approach is that it is easy to understand, and more
flexible than directly populating the database, or using XML.
* It also allows for new defaults to take effect and only transfers what is
needed.
* If a new database schema has a NEW ATTRIBUTE, it is simply not
populated (using a default value will make the transition smooth :)
* If a new database schema REMOVES AN ATTRIBUTE, it is simply ignored
and the data moves across safely (I'm assuming we don't want this
attribute anymore.
* Problems may only occur if there is a new model and is now a required
ForeignKey for an existing model. But this is easy to fix by editing the
populate script. Half of the job is already done as all ForeingKey
lookups occur though the locate_object() function in the generated script.
Improvements:
See TODOs and FIXMEs scattered throughout :-)
"""
import datetime
import sys
from django.apps import apps
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ObjectDoesNotExist
from django.core.management.base import BaseCommand
from django.db import router
from django.db.models import (
AutoField,
BooleanField,
DateField,
DateTimeField,
FileField,
ForeignKey,
)
from django.db.models.deletion import Collector
from django.utils import timezone
from django.utils.encoding import force_str, smart_str
from django_extensions.management.utils import signalcommand
def orm_item_locator(orm_obj):
"""
Is called every time an object that will not be exported is required.
Where orm_obj is the referred object.
We postpone the lookup to locate_object() which will be run on the generated script
"""
the_class = orm_obj._meta.object_name
original_class = the_class
pk_name = orm_obj._meta.pk.name
original_pk_name = pk_name
pk_value = getattr(orm_obj, pk_name)
while (
hasattr(pk_value, "_meta")
and hasattr(pk_value._meta, "pk")
and hasattr(pk_value._meta.pk, "name")
):
the_class = pk_value._meta.object_name
pk_name = pk_value._meta.pk.name
pk_value = getattr(pk_value, pk_name)
clean_dict = make_clean_dict(orm_obj.__dict__)
for key in clean_dict:
v = clean_dict[key]
if v is not None:
if isinstance(v, datetime.datetime):
if not timezone.is_aware(v):
v = timezone.make_aware(v)
clean_dict[key] = StrToCodeChanger(
'dateutil.parser.parse("%s")' % v.isoformat()
)
elif not isinstance(v, (str, int, float)):
clean_dict[key] = str("%s" % v)
output = """ importer.locate_object(%s, "%s", %s, "%s", %s, %s ) """ % (
original_class,
original_pk_name,
the_class,
pk_name,
pk_value,
clean_dict,
)
return output
class Command(BaseCommand):
help = "Dumps the data as a customised python script."
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument("appname", nargs="+")
parser.add_argument(
"--autofield",
action="store_false",
dest="skip_autofield",
default=True,
help="Include Autofields (like pk fields)",
)
@signalcommand
def handle(self, *args, **options):
app_labels = options["appname"]
# Get the models we want to export
models = get_models(app_labels)
# A dictionary is created to keep track of all the processed objects,
# so that foreign key references can be made using python variable names.
# This variable "context" will be passed around like the town bicycle.
context = {}
# Create a dumpscript object and let it format itself as a string
script = Script(
models=models,
context=context,
stdout=self.stdout,
stderr=self.stderr,
options=options,
)
self.stdout.write(str(script))
self.stdout.write("\n")
def get_models(app_labels):
"""
Get a list of models for the given app labels, with some exceptions.
TODO: If a required model is referenced, it should also be included.
Or at least discovered with a get_or_create() call.
"""
# These models are not to be outputted,
# e.g. because they can be generated automatically
# TODO: This should be "appname.modelname" string
EXCLUDED_MODELS = (ContentType,)
models = []
# If no app labels are given, return all
if not app_labels:
for app in apps.get_app_configs():
models += [
m
for m in apps.get_app_config(app.label).get_models()
if m not in EXCLUDED_MODELS
]
return models
# Get all relevant apps
for app_label in app_labels:
# If a specific model is mentioned, get only that model
if "." in app_label:
app_label, model_name = app_label.split(".", 1)
models.append(apps.get_model(app_label, model_name))
# Get all models for a given app
else:
models += [
m
for m in apps.get_app_config(app_label).get_models()
if m not in EXCLUDED_MODELS
]
return models
class Code:
"""
A snippet of python script.
This keeps track of import statements and can be output to a string.
In the future, other features such as custom indentation might be included
in this class.
"""
def __init__(self, indent=-1, stdout=None, stderr=None):
if not stdout:
stdout = sys.stdout
if not stderr:
stderr = sys.stderr
self.indent = indent
self.stdout = stdout
self.stderr = stderr
def __str__(self):
"""Return a string representation of this script."""
if self.imports:
self.stderr.write(repr(self.import_lines))
return flatten_blocks(
[""] + self.import_lines + [""] + self.lines, num_indents=self.indent
)
else:
return flatten_blocks(self.lines, num_indents=self.indent)
def get_import_lines(self):
"""Take the stored imports and converts them to lines"""
if self.imports:
return [
"from %s import %s" % (value, key)
for key, value in self.imports.items()
]
else:
return []
import_lines = property(get_import_lines)
class ModelCode(Code):
"""Produces a python script that can recreate data for a given model class."""
def __init__(self, model, context=None, stdout=None, stderr=None, options=None):
super().__init__(indent=0, stdout=stdout, stderr=stderr)
self.model = model
if context is None:
context = {}
self.context = context
self.options = options
self.instances = []
def get_imports(self):
"""
Return a dictionary of import statements, with the variable being
defined as the key.
"""
return {self.model.__name__: smart_str(self.model.__module__)}
imports = property(get_imports)
def get_lines(self):
"""
Return a list of lists or strings, representing the code body.
Each list is a block, each string is a statement.
"""
code = []
for counter, item in enumerate(self.model._default_manager.all()):
instance = InstanceCode(
instance=item,
id=counter + 1,
context=self.context,
stdout=self.stdout,
stderr=self.stderr,
options=self.options,
)
self.instances.append(instance)
if instance.waiting_list:
code += instance.lines
# After each instance has been processed, try again.
# This allows self referencing fields to work.
for instance in self.instances:
if instance.waiting_list:
code += instance.lines
return code
lines = property(get_lines)
class InstanceCode(Code):
"""Produces a python script that can recreate data for a given model instance."""
def __init__(
self, instance, id, context=None, stdout=None, stderr=None, options=None
):
"""We need the instance in question and an id"""
super().__init__(indent=0, stdout=stdout, stderr=stderr)
self.imports = {}
self.options = options
self.instance = instance
self.model = self.instance.__class__
if context is None:
context = {}
self.context = context
self.variable_name = "%s_%s" % (self.instance._meta.db_table, id)
self.skip_me = None
self.instantiated = False
self.waiting_list = list(self.model._meta.fields)
self.many_to_many_waiting_list = {}
for field in self.model._meta.many_to_many:
try:
if not field.remote_field.through._meta.auto_created:
continue
except AttributeError:
pass
self.many_to_many_waiting_list[field] = list(
getattr(self.instance, field.name).all()
)
def get_lines(self, force=False):
"""
Return a list of lists or strings, representing the code body.
Each list is a block, each string is a statement.
force (True or False): if an attribute object cannot be included,
it is usually skipped to be processed later. With 'force' set, there
will be no waiting: a get_or_create() call is written instead.
"""
code_lines = []
# Don't return anything if this is an instance that should be skipped
if self.skip():
return []
# Initialise our new object
# e.g. model_name_35 = Model()
code_lines += self.instantiate()
# Add each field
# e.g. model_name_35.field_one = 1034.91
# model_name_35.field_two = "text"
code_lines += self.get_waiting_list()
if force:
# TODO: Check that M2M are not affected
code_lines += self.get_waiting_list(force=force)
# Print the save command for our new object
# e.g. model_name_35.save()
if code_lines:
code_lines.append(
"%s = importer.save_or_locate(%s)\n"
% (self.variable_name, self.variable_name)
)
code_lines += self.get_many_to_many_lines(force=force)
return code_lines
lines = property(get_lines)
def skip(self):
"""
Determine whether or not this object should be skipped.
If this model instance is a parent of a single subclassed
instance, skip it. The subclassed instance will create this
parent instance for us.
TODO: Allow the user to force its creation?
"""
if self.skip_me is not None:
return self.skip_me
cls = self.instance.__class__
using = router.db_for_write(cls, instance=self.instance)
collector = Collector(using=using)
collector.collect([self.instance], collect_related=False)
sub_objects = sum([list(i) for i in collector.data.values()], [])
sub_objects_parents = [so._meta.parents for so in sub_objects]
if [self.model in p for p in sub_objects_parents].count(True) == 1:
# since this instance isn't explicitly created, it's variable name
# can't be referenced in the script, so record None in context dict
pk_name = self.instance._meta.pk.name
key = "%s_%s" % (self.model.__name__, getattr(self.instance, pk_name))
self.context[key] = None
self.skip_me = True
else:
self.skip_me = False
return self.skip_me
def instantiate(self):
"""Write lines for instantiation"""
# e.g. model_name_35 = Model()
code_lines = []
if not self.instantiated:
code_lines.append("%s = %s()" % (self.variable_name, self.model.__name__))
self.instantiated = True
# Store our variable name for future foreign key references
pk_name = self.instance._meta.pk.name
key = "%s_%s" % (self.model.__name__, getattr(self.instance, pk_name))
self.context[key] = self.variable_name
return code_lines
def get_waiting_list(self, force=False):
"""Add lines for any waiting fields that can be completed now."""
code_lines = []
skip_autofield = self.options["skip_autofield"]
# Process normal fields
for field in list(self.waiting_list):
try:
# Find the value, add the line, remove from waiting list and move on
value = get_attribute_value(
self.instance,
field,
self.context,
force=force,
skip_autofield=skip_autofield,
)
code_lines.append(
"%s.%s = %s" % (self.variable_name, field.name, value)
)
self.waiting_list.remove(field)
except SkipValue:
# Remove from the waiting list and move on
self.waiting_list.remove(field)
continue
except DoLater:
# Move on, maybe next time
continue
return code_lines
def get_many_to_many_lines(self, force=False):
"""Generate lines that define many to many relations for this instance."""
lines = []
for field, rel_items in self.many_to_many_waiting_list.items():
for rel_item in list(rel_items):
try:
pk_name = rel_item._meta.pk.name
key = "%s_%s" % (
rel_item.__class__.__name__,
getattr(rel_item, pk_name),
)
value = "%s" % self.context[key]
lines.append(
"%s.%s.add(%s)" % (self.variable_name, field.name, value)
)
self.many_to_many_waiting_list[field].remove(rel_item)
except KeyError:
if force:
item_locator = orm_item_locator(rel_item)
self.context["__extra_imports"][rel_item._meta.object_name] = (
rel_item.__module__
)
lines.append(
"%s.%s.add( %s )"
% (self.variable_name, field.name, item_locator)
)
self.many_to_many_waiting_list[field].remove(rel_item)
if lines:
lines.append("")
return lines
class Script(Code):
"""Produces a complete python script that can recreate data for the given apps."""
def __init__(self, models, context=None, stdout=None, stderr=None, options=None):
super().__init__(stdout=stdout, stderr=stderr)
self.imports = {}
self.models = models
if context is None:
context = {}
self.context = context
self.context["__avaliable_models"] = set(models)
self.context["__extra_imports"] = {}
self.options = options
def _queue_models(self, models, context):
"""
Work an an appropriate ordering for the models.
This isn't essential, but makes the script look nicer because
more instances can be defined on their first try.
"""
model_queue = []
number_remaining_models = len(models)
# Max number of cycles allowed before we call it an infinite loop.
MAX_CYCLES = number_remaining_models
allowed_cycles = MAX_CYCLES
while number_remaining_models > 0:
previous_number_remaining_models = number_remaining_models
model = models.pop(0)
# If the model is ready to be processed, add it to the list
if check_dependencies(model, model_queue, context["__avaliable_models"]):
model_class = ModelCode(
model=model,
context=context,
stdout=self.stdout,
stderr=self.stderr,
options=self.options,
)
model_queue.append(model_class)
# Otherwise put the model back at the end of the list
else:
models.append(model)
# Check for infinite loops.
# This means there is a cyclic foreign key structure
# That cannot be resolved by re-ordering
number_remaining_models = len(models)
if number_remaining_models == previous_number_remaining_models:
allowed_cycles -= 1
if allowed_cycles <= 0:
# Add remaining models, but do not remove them from the model list
missing_models = [
ModelCode(
model=m,
context=context,
stdout=self.stdout,
stderr=self.stderr,
options=self.options,
)
for m in models
]
model_queue += missing_models
# Replace the models with the model class objects
# (sure, this is a little bit of hackery)
models[:] = missing_models
break
else:
allowed_cycles = MAX_CYCLES
return model_queue
def get_lines(self):
"""
Return a list of lists or strings, representing the code body.
Each list is a block, each string is a statement.
"""
code = [self.FILE_HEADER.strip()]
# Queue and process the required models
for model_class in self._queue_models(self.models, context=self.context):
msg = "Processing model: %s.%s\n" % (
model_class.model.__module__,
model_class.model.__name__,
)
self.stderr.write(msg)
code.append(" # " + msg)
code.append(model_class.import_lines)
code.append("")
code.append(model_class.lines)
# Process left over foreign keys from cyclic models
for model in self.models:
msg = "Re-processing model: %s.%s\n" % (
model.model.__module__,
model.model.__name__,
)
self.stderr.write(msg)
code.append(" # " + msg)
for instance in model.instances:
if instance.waiting_list or instance.many_to_many_waiting_list:
code.append(instance.get_lines(force=True))
code.insert(1, " # Initial Imports")
code.insert(2, "")
for key, value in self.context["__extra_imports"].items():
code.insert(2, " from %s import %s" % (value, key))
return code
lines = property(get_lines)
# A user-friendly file header
FILE_HEADER = """
#!/usr/bin/env python
# This file has been automatically generated.
# Instead of changing it, create a file called import_helper.py
# and put there a class called ImportHelper(object) in it.
#
# This class will be specially cast so that instead of extending object,
# it will actually extend the class BasicImportHelper()
#
# That means you just have to overload the methods you want to
# change, leaving the other ones intact.
#
# Something that you might want to do is use transactions, for example.
#
# Also, don't forget to add the necessary Django imports.
#
# This file was generated with the following command:
# %s
#
# to restore it, run
# manage.py runscript module_name.this_script_name
#
# example: if manage.py is at ./manage.py
# and the script is at ./some_folder/some_script.py
# you must make sure ./some_folder/__init__.py exists
# and run ./manage.py runscript some_folder.some_script
import os, sys
from django.db import transaction
class BasicImportHelper:
def pre_import(self):
pass
@transaction.atomic
def run_import(self, import_data):
import_data()
def post_import(self):
pass
def locate_similar(self, current_object, search_data):
# You will probably want to call this method from save_or_locate()
# Example:
# new_obj = self.locate_similar(the_obj, {"national_id": the_obj.national_id } )
the_obj = current_object.__class__.objects.get(**search_data)
return the_obj
def locate_object(self, original_class, original_pk_name, the_class, pk_name, pk_value, obj_content):
# You may change this function to do specific lookup for specific objects
#
# original_class class of the django orm's object that needs to be located
# original_pk_name the primary key of original_class
# the_class parent class of original_class which contains obj_content
# pk_name the primary key of original_class
# pk_value value of the primary_key
# obj_content content of the object which was not exported.
#
# You should use obj_content to locate the object on the target db
#
# An example where original_class and the_class are different is
# when original_class is Farmer and the_class is Person. The table
# may refer to a Farmer but you will actually need to locate Person
# in order to instantiate that Farmer
#
# Example:
# if the_class == SurveyResultFormat or the_class == SurveyType or the_class == SurveyState:
# pk_name="name"
# pk_value=obj_content[pk_name]
# if the_class == StaffGroup:
# pk_value=8
search_data = { pk_name: pk_value }
the_obj = the_class.objects.get(**search_data)
#print(the_obj)
return the_obj
def save_or_locate(self, the_obj):
# Change this if you want to locate the object in the database
try:
the_obj.save()
except:
print("---------------")
print("Error saving the following object:")
print(the_obj.__class__)
print(" ")
print(the_obj.__dict__)
print(" ")
print(the_obj)
print(" ")
print("---------------")
raise
return the_obj
importer = None
try:
import import_helper
# We need this so ImportHelper can extend BasicImportHelper, although import_helper.py
# has no knowlodge of this class
importer = type("DynamicImportHelper", (import_helper.ImportHelper, BasicImportHelper ) , {} )()
except ImportError as e:
# From Python 3.3 we can check e.name - string match is for backward compatibility.
if 'import_helper' in str(e):
importer = BasicImportHelper()
else:
raise
import datetime
from decimal import Decimal
from django.contrib.contenttypes.models import ContentType
try:
import dateutil.parser
from dateutil.tz import tzoffset
except ImportError:
print("Please install python-dateutil")
sys.exit(os.EX_USAGE)
def run():
importer.pre_import()
importer.run_import(import_data)
importer.post_import()
def import_data():
""" % " ".join(sys.argv) # noqa: E501
# HELPER FUNCTIONS
# -------------------------------------------------------------------------------
def flatten_blocks(lines, num_indents=-1):
"""
Take a list (block) or string (statement) and flattens it into a string
with indentation.
"""
# The standard indent is four spaces
INDENTATION = " " * 4
if not lines:
return ""
# If this is a string, add the indentation and finish here
if isinstance(lines, str):
return INDENTATION * num_indents + lines
# If this is not a string, join the lines and recurse
return "\n".join([flatten_blocks(line, num_indents + 1) for line in lines])
def get_attribute_value(item, field, context, force=False, skip_autofield=True):
"""Get a string version of the given attribute's value, like repr() might."""
# Find the value of the field, catching any database issues
try:
value = getattr(item, field.name)
except ObjectDoesNotExist:
raise SkipValue(
"Could not find object for %s.%s, ignoring.\n"
% (item.__class__.__name__, field.name)
)
# AutoField: We don't include the auto fields, they'll be automatically recreated
if skip_autofield and isinstance(field, AutoField):
raise SkipValue()
# Some databases (eg MySQL) might store boolean values as 0/1,
# this needs to be cast as a bool
elif isinstance(field, BooleanField) and value is not None:
return repr(bool(value))
# Post file-storage-refactor, repr() on File/ImageFields no longer returns the path
elif isinstance(field, FileField):
return repr(force_str(value))
# ForeignKey fields, link directly using our stored python variable name
elif isinstance(field, ForeignKey) and value is not None:
# Special case for contenttype foreign keys: no need to output any
# content types in this script, as they can be generated again
# automatically.
# NB: Not sure if "is" will always work
if field.remote_field.model is ContentType:
return 'ContentType.objects.get(app_label="%s", model="%s")' % (
value.app_label,
value.model,
)
# Generate an identifier (key) for this foreign object
pk_name = value._meta.pk.name
key = "%s_%s" % (value.__class__.__name__, getattr(value, pk_name))
if key in context:
variable_name = context[key]
# If the context value is set to None, this should be skipped.
# This identifies models that have been skipped (inheritance)
if variable_name is None:
raise SkipValue()
# Return the variable name listed in the context
return "%s" % variable_name
elif value.__class__ not in context["__avaliable_models"] or force:
context["__extra_imports"][value._meta.object_name] = value.__module__
item_locator = orm_item_locator(value)
return item_locator
else:
raise DoLater("(FK) %s.%s\n" % (item.__class__.__name__, field.name))
elif isinstance(field, (DateField, DateTimeField)) and value is not None:
return 'dateutil.parser.parse("%s")' % value.isoformat()
# A normal field (e.g. a python built-in)
else:
return repr(value)
def make_clean_dict(the_dict):
if "_state" in the_dict:
clean_dict = the_dict.copy()
del clean_dict["_state"]
return clean_dict
return the_dict
def check_dependencies(model, model_queue, avaliable_models):
"""Check that all the depenedencies for this model are already in the queue."""
# A list of allowed links: existing fields, itself and the special case ContentType
allowed_links = [m.model.__name__ for m in model_queue] + [
model.__name__,
"ContentType",
]
# For each ForeignKey or ManyToMany field, check that a link is possible
for field in model._meta.fields:
if not field.remote_field:
continue
if field.remote_field.model.__name__ not in allowed_links:
if field.remote_field.model not in avaliable_models:
continue
return False
for field in model._meta.many_to_many:
if not field.remote_field:
continue
if field.remote_field.model.__name__ not in allowed_links:
return False
return True
# EXCEPTIONS
# -------------------------------------------------------------------------------
class SkipValue(Exception):
"""Value could not be parsed or should simply be skipped."""
class DoLater(Exception):
"""Value could not be parsed or should simply be skipped."""
class StrToCodeChanger:
def __init__(self, string):
self.repr = string
def __repr__(self):
return self.repr

View File

@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
import sys
import csv
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
FORMATS = [
"address",
"emails",
"google",
"outlook",
"linkedin",
"vcard",
]
def full_name(**kwargs):
"""Return full name or username."""
first_name = kwargs.get("first_name")
last_name = kwargs.get("last_name")
name = " ".join(n for n in [first_name, last_name] if n)
if name:
return name
name = kwargs.get("name")
if name:
return name
username = kwargs.get("username")
if username:
return username
return ""
class Command(BaseCommand):
help = "Export user email address list in one of a number of formats."
args = "[output file]"
label = "filename to save to"
can_import_settings = True
encoding = "utf-8" # RED_FLAG: add as an option -DougN
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.UserModel = get_user_model()
def add_arguments(self, parser):
super().add_arguments(parser)
(
parser.add_argument(
"--group",
"-g",
action="store",
dest="group",
default=None,
help="Limit to users which are part of the supplied group name",
),
)
parser.add_argument(
"--format",
"-f",
action="store",
dest="format",
default=FORMATS[0],
help="output format. May be one of %s." % ", ".join(FORMATS),
)
def full_name(self, **kwargs):
return getattr(settings, "EXPORT_EMAILS_FULL_NAME_FUNC", full_name)(**kwargs)
@signalcommand
def handle(self, *args, **options):
if len(args) > 1:
raise CommandError("extra arguments supplied")
group = options["group"]
if group and not Group.objects.filter(name=group).count() == 1:
names = "', '".join(g["name"] for g in Group.objects.values("name"))
if names:
names = "'" + names + "'."
raise CommandError(
"Unknown group '" + group + "'. Valid group names are: " + names
)
UserModel = get_user_model()
order_by = getattr(
settings,
"EXPORT_EMAILS_ORDER_BY",
["last_name", "first_name", "username", "email"],
)
fields = getattr(
settings,
"EXPORT_EMAILS_FIELDS",
["last_name", "first_name", "username", "email"],
)
qs = UserModel.objects.all().order_by(*order_by)
if group:
qs = qs.filter(groups__name=group).distinct()
qs = qs.values(*fields)
getattr(self, options["format"])(qs)
def address(self, qs):
"""
Single entry per line in the format of:
"full name" <my@address.com>;
"""
self.stdout.write(
"\n".join(
'"%s" <%s>;' % (self.full_name(**ent), ent.get("email", ""))
for ent in qs
)
)
self.stdout.write("\n")
def emails(self, qs):
"""
Single entry with email only in the format of:
my@address.com,
"""
self.stdout.write(",\n".join(ent["email"] for ent in qs if ent.get("email")))
self.stdout.write("\n")
def google(self, qs):
"""CSV format suitable for importing into google GMail"""
csvf = csv.writer(sys.stdout)
csvf.writerow(["Name", "Email"])
for ent in qs:
csvf.writerow([self.full_name(**ent), ent.get("email", "")])
def linkedin(self, qs):
"""
CSV format suitable for importing into linkedin Groups.
perfect for pre-approving members of a linkedin group.
"""
csvf = csv.writer(sys.stdout)
csvf.writerow(["First Name", "Last Name", "Email"])
for ent in qs:
csvf.writerow(
[
ent.get("first_name", ""),
ent.get("last_name", ""),
ent.get("email", ""),
]
)
def outlook(self, qs):
"""CSV format suitable for importing into outlook"""
csvf = csv.writer(sys.stdout)
columns = [
"Name",
"E-mail Address",
"Notes",
"E-mail 2 Address",
"E-mail 3 Address",
"Mobile Phone",
"Pager",
"Company",
"Job Title",
"Home Phone",
"Home Phone 2",
"Home Fax",
"Home Address",
"Business Phone",
"Business Phone 2",
"Business Fax",
"Business Address",
"Other Phone",
"Other Fax",
"Other Address",
]
csvf.writerow(columns)
empty = [""] * (len(columns) - 2)
for ent in qs:
csvf.writerow([self.full_name(**ent), ent.get("email", "")] + empty)
def vcard(self, qs):
"""VCARD format."""
try:
import vobject
except ImportError:
print(
self.style.ERROR(
"Please install vobject to use the vcard export format."
)
)
sys.exit(1)
out = sys.stdout
for ent in qs:
card = vobject.vCard()
card.add("fn").value = self.full_name(**ent)
if ent.get("last_name") and ent.get("first_name"):
card.add("n").value = vobject.vcard.Name(
ent["last_name"], ent["first_name"]
)
else:
# fallback to fullname, if both first and lastname are not declared
card.add("n").value = vobject.vcard.Name(self.full_name(**ent))
if ent.get("email"):
emailpart = card.add("email")
emailpart.value = ent["email"]
emailpart.type_param = "INTERNET"
out.write(card.serialize())

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
import sys
from django.core.management.base import LabelCommand
from django.template import TemplateDoesNotExist, loader
from django_extensions.management.utils import signalcommand
class Command(LabelCommand):
help = "Finds the location of the given template by resolving its path"
args = "[template_path]"
label = "template path"
@signalcommand
def handle_label(self, template_path, **options):
try:
template = loader.get_template(template_path).template
except TemplateDoesNotExist:
sys.stderr.write("No template found\n")
else:
sys.stdout.write(self.style.SUCCESS((template.name)))

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
import argparse
import string
import secrets
from typing import List
from django.core.management.base import BaseCommand
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
help = "Generates a simple new password that can be used for a user password. "
"Uses Pythons secrets module to generate passwords. Do not use this command to "
"generate your most secure passwords."
requires_system_checks: List[str] = []
def add_arguments(self, parser):
parser.add_argument(
"-l", "--length", nargs="?", type=int, default=16, help="Password length."
)
parser.add_argument(
"-c",
"--complex",
action=argparse.BooleanOptionalAction,
help="More complex alphabet, includes punctuation",
)
@signalcommand
def handle(self, *args, **options):
length = options["length"]
alphabet = string.ascii_letters + string.digits
if options["complex"]:
alphabet += string.punctuation
return "".join(secrets.choice(alphabet) for i in range(length))

View File

@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
from typing import List
from django.core.management.base import BaseCommand
from django.core.management.utils import get_random_secret_key
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
help = "Generates a new SECRET_KEY that can be used in a project settings file."
requires_system_checks: List[str] = []
@signalcommand
def handle(self, *args, **options):
return get_random_secret_key()

View File

@@ -0,0 +1,486 @@
# -*- coding: utf-8 -*-
import sys
import json
import os
import tempfile
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.template import loader
from django_extensions.management.modelviz import ModelGraph, generate_dot
from django_extensions.management.utils import signalcommand
try:
import pygraphviz
HAS_PYGRAPHVIZ = True
except ImportError:
HAS_PYGRAPHVIZ = False
try:
try:
import pydotplus as pydot
except ImportError:
import pydot
HAS_PYDOT = True
except ImportError:
HAS_PYDOT = False
def retheme(graph_data, app_style={}):
if isinstance(app_style, str):
if os.path.exists(app_style):
try:
with open(app_style, "rt") as f:
app_style = json.load(f)
except Exception as e:
print(f"Invalid app style file {app_style}")
raise Exception(e)
else:
return graph_data
for gc in graph_data["graphs"]:
for g in gc:
if "name" in g:
for m in g["models"]:
app_name = g["app_name"]
if app_name in app_style:
m["style"] = app_style[app_name]
return graph_data
class Command(BaseCommand):
help = "Creates a GraphViz dot file for the specified app names."
" You can pass multiple app names and they will all be combined into a"
" single model. Output is usually directed to a dot file."
can_import_settings = True
def __init__(self, *args, **kwargs):
"""
Allow defaults for arguments to be set in settings.GRAPH_MODELS.
Each argument in self.arguments is a dict where the key is the
space-separated args and the value is our kwarg dict.
The default from settings is keyed as the long arg name with '--'
removed and any '-' replaced by '_'. For example, the default value for
--disable-fields can be set in settings.GRAPH_MODELS['disable_fields'].
"""
self.arguments = {
"--app-style": {
"action": "store",
"help": "Path to style json to configure the style per app",
"dest": "app-style",
"default": ".app-style.json",
},
"--pygraphviz": {
"action": "store_true",
"default": False,
"dest": "pygraphviz",
"help": "Output graph data as image using PyGraphViz.",
},
"--pydot": {
"action": "store_true",
"default": False,
"dest": "pydot",
"help": "Output graph data as image using PyDot(Plus).",
},
"--dot": {
"action": "store_true",
"default": False,
"dest": "dot",
"help": (
"Output graph data as raw DOT (graph description language) "
"text data."
),
},
"--json": {
"action": "store_true",
"default": False,
"dest": "json",
"help": "Output graph data as JSON",
},
"--disable-fields -d": {
"action": "store_true",
"default": False,
"dest": "disable_fields",
"help": "Do not show the class member fields",
},
"--disable-abstract-fields": {
"action": "store_true",
"default": False,
"dest": "disable_abstract_fields",
"help": "Do not show the class member fields that were inherited",
},
"--display-field-choices": {
"action": "store_true",
"default": False,
"dest": "display_field_choices",
"help": "Display choices instead of field type",
},
"--group-models -g": {
"action": "store_true",
"default": False,
"dest": "group_models",
"help": "Group models together respective to their application",
},
"--all-applications -a": {
"action": "store_true",
"default": False,
"dest": "all_applications",
"help": "Automatically include all applications from INSTALLED_APPS",
},
"--output -o": {
"action": "store",
"dest": "outputfile",
"help": (
"Render output file. Type of output dependend on file extensions. "
"Use png or jpg to render graph to image."
),
},
"--layout -l": {
"action": "store",
"dest": "layout",
"default": "dot",
"help": "Layout to be used by GraphViz for visualization. Layouts: "
"circo dot fdp neato nop nop1 nop2 twopi",
},
"--theme -t": {
"action": "store",
"dest": "theme",
"default": "django2018",
"help": "Theme to use. Supplied are 'original' and 'django2018'. "
"You can create your own by creating dot templates in "
"'django_extentions/graph_models/themename/' template directory.",
},
"--verbose-names -n": {
"action": "store_true",
"default": False,
"dest": "verbose_names",
"help": "Use verbose_name of models and fields",
},
"--language -L": {
"action": "store",
"dest": "language",
"help": "Specify language used for verbose_name localization",
},
"--exclude-columns -x": {
"action": "store",
"dest": "exclude_columns",
"help": "Exclude specific column(s) from the graph. "
"Can also load exclude list from file.",
},
"--exclude-models -X": {
"action": "store",
"dest": "exclude_models",
"help": "Exclude specific model(s) from the graph. Can also load "
"exclude list from file. Wildcards (*) are allowed.",
},
"--include-models -I": {
"action": "store",
"dest": "include_models",
"help": "Restrict the graph to specified models. "
"Wildcards (*) are allowed.",
},
"--inheritance -e": {
"action": "store_true",
"default": True,
"dest": "inheritance",
"help": "Include inheritance arrows (default)",
},
"--no-inheritance -E": {
"action": "store_false",
"default": False,
"dest": "inheritance",
"help": "Do not include inheritance arrows",
},
"--hide-relations-from-fields -R": {
"action": "store_false",
"default": True,
"dest": "relations_as_fields",
"help": "Do not show relations as fields in the graph.",
},
"--relation-fields-only": {
"action": "store",
"default": False,
"dest": "relation_fields_only",
"help": "Only display fields that are relevant for relations",
},
"--disable-sort-fields -S": {
"action": "store_false",
"default": True,
"dest": "sort_fields",
"help": "Do not sort fields",
},
"--hide-edge-labels": {
"action": "store_true",
"default": False,
"dest": "hide_edge_labels",
"help": "Do not show relations labels in the graph.",
},
"--arrow-shape": {
"action": "store",
"default": "dot",
"dest": "arrow_shape",
"choices": [
"box",
"crow",
"curve",
"icurve",
"diamond",
"dot",
"inv",
"none",
"normal",
"tee",
"vee",
],
"help": "Arrow shape to use for relations. Default is dot. "
"Available shapes: box, crow, curve, icurve, diamond, dot, inv, "
"none, normal, tee, vee.",
},
"--color-code-deletions": {
"action": "store_true",
"default": False,
"dest": "color_code_deletions",
"help": "Color the relations according to their on_delete setting, "
"where it is applicable. The colors are: red (CASCADE), "
"orange (SET_NULL), green (SET_DEFAULT), yellow (SET), "
"blue (PROTECT), grey (DO_NOTHING), and purple (RESTRICT).",
},
"--rankdir": {
"action": "store",
"default": "TB",
"dest": "rankdir",
"choices": ["TB", "BT", "LR", "RL"],
"help": "Set direction of graph layout. Supported directions: "
"TB, LR, BT and RL. Corresponding to directed graphs drawn from "
"top to bottom, from left to right, from bottom to top, and from "
"right to left, respectively. Default is TB.",
},
"--ordering": {
"action": "store",
"default": None,
"dest": "ordering",
"choices": ["in", "out"],
"help": "Controls how the edges are arranged. Supported orderings: "
'"in" (incoming relations first), "out" (outgoing relations first). '
"Default is None.",
},
}
defaults = getattr(settings, "GRAPH_MODELS", None)
if defaults:
for argument in self.arguments:
arg_split = argument.split(" ")
setting_opt = arg_split[0].lstrip("-").replace("-", "_")
if setting_opt in defaults:
self.arguments[argument]["default"] = defaults[setting_opt]
super().__init__(*args, **kwargs)
def add_arguments(self, parser):
"""Unpack self.arguments for parser.add_arguments."""
parser.add_argument("app_label", nargs="*")
for argument in self.arguments:
parser.add_argument(*argument.split(" "), **self.arguments[argument])
@signalcommand
def handle(self, *args, **options):
args = options["app_label"]
if not args and not options["all_applications"]:
default_app_labels = getattr(settings, "GRAPH_MODELS", {}).get("app_labels")
if default_app_labels:
args = default_app_labels
else:
raise CommandError("need one or more arguments for appname")
# Determine output format based on options, file extension, and library
# availability.
outputfile = options.get("outputfile") or ""
_, outputfile_ext = os.path.splitext(outputfile)
outputfile_ext = outputfile_ext.lower()
output_opts_names = ["pydot", "pygraphviz", "json", "dot"]
output_opts = {k: v for k, v in options.items() if k in output_opts_names}
output_opts_count = sum(output_opts.values())
if output_opts_count > 1:
raise CommandError(
"Only one of %s can be set."
% ", ".join(["--%s" % opt for opt in output_opts_names])
)
if output_opts_count == 1:
output = next(key for key, val in output_opts.items() if val)
elif not outputfile:
# When neither outputfile nor a output format option are set,
# default to printing .dot format to stdout. Kept for backward
# compatibility.
output = "dot"
elif outputfile_ext == ".dot":
output = "dot"
elif outputfile_ext == ".json":
output = "json"
elif HAS_PYGRAPHVIZ:
output = "pygraphviz"
elif HAS_PYDOT:
output = "pydot"
else:
raise CommandError(
"Neither pygraphviz nor pydotplus could be found to generate the image."
" To generate text output, use the --json or --dot options."
)
if options.get("rankdir") != "TB" and output not in [
"pydot",
"pygraphviz",
"dot",
]:
raise CommandError(
"--rankdir is not supported for the chosen output format"
)
if options.get("ordering") and output not in ["pydot", "pygraphviz", "dot"]:
raise CommandError(
"--ordering is not supported for the chosen output format"
)
# Consistency check: Abort if --pygraphviz or --pydot options are set
# but no outputfile is specified. Before 2.1.4 this silently fell back
# to printind .dot format to stdout.
if output in ["pydot", "pygraphviz"] and not outputfile:
raise CommandError(
"An output file (--output) must be specified when --pydot or "
"--pygraphviz are set."
)
cli_options = " ".join(sys.argv[2:])
graph_models = ModelGraph(args, cli_options=cli_options, **options)
graph_models.generate_graph_data()
if output == "json":
graph_data = graph_models.get_graph_data(as_json=True)
return self.render_output_json(graph_data, outputfile)
graph_data = graph_models.get_graph_data(as_json=False)
theme = options["theme"]
template_name = os.path.join(
"django_extensions", "graph_models", theme, "digraph.dot"
)
template = loader.get_template(template_name)
graph_data = retheme(graph_data, app_style=options["app-style"])
dotdata = generate_dot(graph_data, template=template)
if output == "pygraphviz":
return self.render_output_pygraphviz(dotdata, **options)
if output == "pydot":
return self.render_output_pydot(dotdata, **options)
self.print_output(dotdata, outputfile)
def print_output(self, dotdata, output_file=None):
"""Write model data to file or stdout in DOT (text) format."""
if isinstance(dotdata, bytes):
dotdata = dotdata.decode()
if output_file:
with open(output_file, "wt") as dot_output_f:
dot_output_f.write(dotdata)
else:
self.stdout.write(dotdata)
def render_output_json(self, graph_data, output_file=None):
"""Write model data to file or stdout in JSON format."""
if output_file:
with open(output_file, "wt") as json_output_f:
json.dump(graph_data, json_output_f)
else:
self.stdout.write(json.dumps(graph_data))
def render_output_pygraphviz(self, dotdata, **kwargs):
"""Render model data as image using pygraphviz."""
if not HAS_PYGRAPHVIZ:
raise CommandError("You need to install pygraphviz python module")
version = pygraphviz.__version__.rstrip("-svn")
try:
if tuple(int(v) for v in version.split(".")) < (0, 36):
# HACK around old/broken AGraph before version 0.36
# (ubuntu ships with this old version)
tmpfile = tempfile.NamedTemporaryFile()
tmpfile.write(dotdata)
tmpfile.seek(0)
dotdata = tmpfile.name
except ValueError:
pass
graph = pygraphviz.AGraph(dotdata)
graph.layout(prog=kwargs["layout"])
graph.draw(kwargs["outputfile"])
def render_output_pydot(self, dotdata, **kwargs):
"""Render model data as image using pydot."""
if not HAS_PYDOT:
raise CommandError("You need to install pydot python module")
graph = pydot.graph_from_dot_data(dotdata)
if not graph:
raise CommandError("pydot returned an error")
if isinstance(graph, (list, tuple)):
if len(graph) > 1:
sys.stderr.write(
"Found more then one graph, rendering only the first one.\n"
)
graph = graph[0]
output_file = kwargs["outputfile"]
formats = [
"bmp",
"canon",
"cmap",
"cmapx",
"cmapx_np",
"dot",
"dia",
"emf",
"em",
"fplus",
"eps",
"fig",
"gd",
"gd2",
"gif",
"gv",
"imap",
"imap_np",
"ismap",
"jpe",
"jpeg",
"jpg",
"metafile",
"pdf",
"pic",
"plain",
"plain-ext",
"png",
"pov",
"ps",
"ps2",
"svg",
"svgz",
"tif",
"tiff",
"tk",
"vml",
"vmlz",
"vrml",
"wbmp",
"webp",
"xdot",
]
ext = output_file[output_file.rfind(".") + 1 :]
format_ = ext if ext in formats else "raw"
graph.write(output_file, format=format_)

View File

@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
# Author: OmenApps. https://omenapps.com
import inspect
from django.apps import apps as django_apps
from django.conf import settings
from django.core.management.base import BaseCommand
from django.db import connection
from django_extensions.management.color import color_style
from django_extensions.management.utils import signalcommand
TAB = " "
HALFTAB = " "
class Command(BaseCommand):
"""A simple management command which lists model fields and methods."""
help = "List out the fields and methods for each model"
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--field-class",
action="store_true",
default=None,
help="show class name of field.",
)
parser.add_argument(
"--db-type",
action="store_true",
default=None,
help="show database column type of field.",
)
parser.add_argument(
"--signature",
action="store_true",
default=None,
help="show the signature of method.",
)
parser.add_argument(
"--all-methods",
action="store_true",
default=None,
help="list all methods, including private and default.",
)
parser.add_argument(
"--model",
nargs="?",
type=str,
default=None,
help="list the details for a single model. "
"Input should be in the form appname.Modelname",
)
def list_model_info(self, options):
style = color_style()
INFO = getattr(style, "INFO", lambda x: x)
WARN = getattr(style, "WARN", lambda x: x)
BOLD = getattr(style, "BOLD", lambda x: x)
FIELD_CLASS = (
True
if options.get("field_class", None) is not None
else getattr(settings, "MODEL_INFO_FIELD_CLASS", False)
)
DB_TYPE = (
True
if options.get("db_type", None) is not None
else getattr(settings, "MODEL_INFO_DB_TYPE", False)
)
SIGNATURE = (
True
if options.get("signature", None) is not None
else getattr(settings, "MODEL_INFO_SIGNATURE", False)
)
ALL_METHODS = (
True
if options.get("all_methods", None) is not None
else getattr(settings, "MODEL_INFO_ALL_METHODS", False)
)
MODEL = (
options.get("model")
if options.get("model", None) is not None
else getattr(settings, "MODEL_INFO_MODEL", False)
)
default_methods = [
"check",
"clean",
"clean_fields",
"date_error_message",
"delete",
"from_db",
"full_clean",
"get_absolute_url",
"get_deferred_fields",
"prepare_database_save",
"refresh_from_db",
"save",
"save_base",
"serializable_value",
"unique_error_message",
"validate_unique",
]
if MODEL:
model_list = [django_apps.get_model(MODEL)]
else:
model_list = sorted(
django_apps.get_models(),
key=lambda x: (x._meta.app_label, x._meta.object_name),
reverse=False,
)
for model in model_list:
self.stdout.write(
INFO(model._meta.app_label + "." + model._meta.object_name)
)
self.stdout.write(BOLD(HALFTAB + "Fields:"))
for field in model._meta.get_fields():
field_info = TAB + field.name + " -"
if FIELD_CLASS:
try:
field_info += " " + field.__class__.__name__
except TypeError:
field_info += WARN(" TypeError (field_class)")
except AttributeError:
field_info += WARN(" AttributeError (field_class)")
if FIELD_CLASS and DB_TYPE:
field_info += ","
if DB_TYPE:
try:
field_info += " " + field.db_type(connection=connection)
except TypeError:
field_info += WARN(" TypeError (db_type)")
except AttributeError:
field_info += WARN(" AttributeError (db_type)")
self.stdout.write(field_info)
if ALL_METHODS:
self.stdout.write(BOLD(HALFTAB + "Methods (all):"))
else:
self.stdout.write(BOLD(HALFTAB + "Methods (non-private/internal):"))
for method_name in dir(model):
try:
method = getattr(model, method_name)
if ALL_METHODS:
if callable(method) and not method_name[0].isupper():
if SIGNATURE:
signature = inspect.signature(method)
else:
signature = "()"
self.stdout.write(TAB + method_name + str(signature))
else:
if (
callable(method)
and not method_name.startswith("_")
and method_name not in default_methods
and not method_name[0].isupper()
):
if SIGNATURE:
signature = inspect.signature(method)
else:
signature = "()"
self.stdout.write(TAB + method_name + str(signature))
except AttributeError:
self.stdout.write(TAB + method_name + WARN(" - AttributeError"))
except ValueError:
self.stdout.write(
TAB
+ method_name
+ WARN(" - ValueError (could not identify signature)")
)
self.stdout.write("\n")
self.stdout.write(INFO("Total Models Listed: %d" % len(model_list)))
@signalcommand
def handle(self, *args, **options):
self.list_model_info(options)

View File

@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
# Based on https://gist.github.com/voldmar/1264102
# and https://gist.github.com/runekaagaard/2eecf0a8367959dc634b7866694daf2c
import gc
import inspect
import weakref
from collections import defaultdict
import django
from django.apps import apps
from django.core.management.base import BaseCommand
from django.db.models.signals import (
ModelSignal,
pre_init,
post_init,
pre_save,
post_save,
pre_delete,
post_delete,
m2m_changed,
pre_migrate,
post_migrate,
)
from django.utils.encoding import force_str
MSG = "{module}.{name} #{line}{is_async}"
SIGNAL_NAMES = {
pre_init: "pre_init",
post_init: "post_init",
pre_save: "pre_save",
post_save: "post_save",
pre_delete: "pre_delete",
post_delete: "post_delete",
m2m_changed: "m2m_changed",
pre_migrate: "pre_migrate",
post_migrate: "post_migrate",
}
class Command(BaseCommand):
help = "List all signals by model and signal type"
def handle(self, *args, **options):
all_models = apps.get_models(include_auto_created=True, include_swapped=True)
model_lookup = {id(m): m for m in all_models}
signals = [obj for obj in gc.get_objects() if isinstance(obj, ModelSignal)]
models = defaultdict(lambda: defaultdict(list))
for signal in signals:
signal_name = SIGNAL_NAMES.get(signal, "unknown")
for receiver in signal.receivers:
if django.VERSION >= (5, 0):
lookup, receiver, is_async = receiver
else:
lookup, receiver = receiver
is_async = False
if isinstance(receiver, weakref.ReferenceType):
receiver = receiver()
if receiver is None:
continue
receiver_id, sender_id = lookup
model = model_lookup.get(sender_id, "_unknown_")
if model:
models[model][signal_name].append(
MSG.format(
name=receiver.__name__,
module=receiver.__module__,
is_async=" (async)" if is_async else "",
line=inspect.getsourcelines(receiver)[1],
path=inspect.getsourcefile(receiver),
)
)
output = []
for key in sorted(models.keys(), key=str):
verbose_name = force_str(key._meta.verbose_name)
output.append(
"{}.{} ({})".format(key.__module__, key.__name__, verbose_name)
)
for signal_name in sorted(models[key].keys()):
lines = models[key][signal_name]
output.append(" {}".format(signal_name))
for line in lines:
output.append(" {}".format(line))
return "\n".join(output)

View File

@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
import asyncio
import sys
try:
from aiosmtpd.controller import Controller
except ImportError:
raise ImportError("Please install 'aiosmtpd' library to use mail_debug command.")
from logging import getLogger
from typing import List
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import setup_logger, signalcommand
logger = getLogger(__name__)
class CustomHandler:
async def handle_DATA(self, server, session, envelope):
"""Output will be sent to the module logger at INFO level."""
peer = session.peer
inheaders = 1
lines = envelope.content.decode("utf8", errors="replace").splitlines()
logger.info("---------- MESSAGE FOLLOWS ----------")
for line in lines:
# headers first
if inheaders and not line:
logger.info("X-Peer: %s" % peer[0])
inheaders = 0
logger.info(line)
logger.info("------------ END MESSAGE ------------")
return "250 OK"
class Command(BaseCommand):
help = "Starts a test mail server for development."
args = "[optional port number or ippaddr:port]"
requires_system_checks: List[str] = []
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument("addrport", nargs="?")
parser.add_argument(
"--output",
dest="output_file",
default=None,
help="Specifies an output file to send a copy of all messages "
"(not flushed immediately).",
)
parser.add_argument(
"--use-settings",
dest="use_settings",
action="store_true",
default=False,
help="Uses EMAIL_HOST and HOST_PORT from Django settings.",
)
@signalcommand
def handle(self, addrport="", *args, **options):
if not addrport:
if options["use_settings"]:
from django.conf import settings
addr = getattr(settings, "EMAIL_HOST", "")
port = str(getattr(settings, "EMAIL_PORT", "1025"))
else:
addr = ""
port = "1025"
else:
try:
addr, port = addrport.split(":")
except ValueError:
addr, port = "", addrport
if not addr:
addr = "127.0.0.1"
if not port.isdigit():
raise CommandError("%r is not a valid port number." % port)
else:
port = int(port)
# Add console handler
setup_logger(logger, stream=self.stdout, filename=options["output_file"])
def inner_run():
quit_command = (sys.platform == "win32") and "CTRL-BREAK" or "CONTROL-C"
print(
"Now accepting mail at %s:%s -- use %s to quit"
% (addr, port, quit_command)
)
handler = CustomHandler()
controller = Controller(handler, hostname=addr, port=port)
controller.start()
loop = asyncio.get_event_loop()
loop.run_forever()
try:
inner_run()
except KeyboardInterrupt:
pass

View File

@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
import json
from operator import itemgetter
from pathlib import Path
from django.core.management import call_command
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.recorder import MigrationRecorder
from django.utils import timezone
from django_extensions.management.utils import signalcommand
DEFAULT_FILENAME = "managestate.json"
DEFAULT_STATE = "default"
class Command(BaseCommand):
help = "Manage database state in the convenient way."
_applied_migrations = None
migrate_args: dict
migrate_options: dict
filename: str
verbosity: int
database: str
conn: BaseDatabaseWrapper
def add_arguments(self, parser):
parser.add_argument(
"action",
choices=("dump", "load"),
help="An action to do. "
"Dump action saves applied migrations to a file. "
"Load action applies migrations specified in a file.",
)
parser.add_argument(
"state",
nargs="?",
default=DEFAULT_STATE,
help="A name of a state. Usually a name of a git branch."
f'Defaults to "{DEFAULT_STATE}"',
)
parser.add_argument(
"-d",
"--database",
default=DEFAULT_DB_ALIAS,
help="Nominates a database to synchronize. "
f'Defaults to the "{DEFAULT_DB_ALIAS}" database.',
)
parser.add_argument(
"-f",
"--filename",
default=DEFAULT_FILENAME,
help=f'A file to write to. Defaults to "{DEFAULT_FILENAME}"',
)
# migrate command arguments
parser.add_argument(
"--noinput",
"--no-input",
action="store_false",
dest="interactive",
help='The argument for "migrate" command. '
"Tells Django to NOT prompt the user for input of any kind.",
)
parser.add_argument(
"--fake",
action="store_true",
help='The argument for "migrate" command. '
"Mark migrations as run without actually running them.",
)
parser.add_argument(
"--fake-initial",
action="store_true",
help='The argument for "migrate" command. '
"Detect if tables already exist and fake-apply initial migrations if so. "
"Make sure that the current database schema matches your initial migration "
"before using this flag. "
"Django will only check for an existing table name.",
)
parser.add_argument(
"--plan",
action="store_true",
help='The argument for "migrate" command. '
"Shows a list of the migration actions that will be performed.",
)
parser.add_argument(
"--run-syncdb",
action="store_true",
help='The argument for "migrate" command. '
"Creates tables for apps without migrations.",
)
parser.add_argument(
"--check",
action="store_true",
dest="check_unapplied",
help='The argument for "migrate" command. '
"Exits with a non-zero status if unapplied migrations exist.",
)
@signalcommand
def handle(self, action, database, filename, state, *args, **options):
self.migrate_args = args
self.migrate_options = options
self.verbosity = options["verbosity"]
self.conn = connections[database]
self.database = database
self.filename = filename
getattr(self, action)(state)
def dump(self, state: str):
"""Save applied migrations to a file."""
migrated_apps = self.get_migrated_apps()
migrated_apps.update(self.get_applied_migrations())
self.write({state: migrated_apps})
self.stdout.write(
self.style.SUCCESS(
f'Migrations for state "{state}" have been successfully '
f"saved to {self.filename}."
)
)
def load(self, state: str):
"""Apply migrations from a file."""
migrations = self.read().get(state)
if migrations is None:
raise CommandError(f"No such state saved: {state}")
kwargs = {
**self.migrate_options,
"database": self.database,
"verbosity": self.verbosity - 1 if self.verbosity > 1 else 0,
}
for app, migration in migrations.items():
if self.is_applied(app, migration):
continue
if self.verbosity > 1:
self.stdout.write(
self.style.WARNING(f'Applying migrations for "{app}"')
)
args = (app, migration, *self.migrate_args)
call_command("migrate", *args, **kwargs)
self.stdout.write(
self.style.SUCCESS(
f'Migrations for "{state}" have been successfully applied.'
)
)
def get_migrated_apps(self) -> dict:
"""Installed apps having migrations."""
apps = MigrationLoader(self.conn).migrated_apps
migrated_apps = dict.fromkeys(apps, "zero")
if self.verbosity > 1:
self.stdout.write(
"Apps having migrations: " + ", ".join(sorted(migrated_apps))
)
return migrated_apps
def get_applied_migrations(self) -> dict:
"""Installed apps with last applied migrations."""
if self._applied_migrations:
return self._applied_migrations
migrations = MigrationRecorder(self.conn).applied_migrations()
last_applied = sorted(migrations.keys(), key=itemgetter(1))
self._applied_migrations = dict(last_applied)
return self._applied_migrations
def is_applied(self, app: str, migration: str) -> bool:
"""Check whether a migration for an app is applied or not."""
applied = self.get_applied_migrations().get(app)
if applied == migration:
if self.verbosity > 1:
self.stdout.write(
self.style.WARNING(f'Migrations for "{app}" are already applied.')
)
return True
return False
def read(self) -> dict:
"""Get saved state from the file."""
path = Path(self.filename)
if not path.exists() or not path.is_file():
raise CommandError(f"No such file: {self.filename}")
with open(self.filename) as file:
return json.load(file)
def write(self, data: dict):
"""Write new data to the file using existent one."""
try:
saved = self.read()
except CommandError:
saved = {}
saved.update(data, updated_at=str(timezone.now()))
with open(self.filename, "w") as file:
json.dump(saved, file, indent=2, sort_keys=True)

View File

@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
from django.apps import apps
from django.contrib.contenttypes.fields import GenericForeignKey
from django.core.management import BaseCommand
from django.db import transaction
from django_extensions.management.utils import signalcommand
def get_model_to_deduplicate():
models = apps.get_models()
iterator = 1
for model in models:
print("%s. %s" % (iterator, model.__name__))
iterator += 1
model_choice = int(
input("Enter the number of the model you would like to de-duplicate:")
)
model_to_deduplicate = models[model_choice - 1]
return model_to_deduplicate
def get_field_names(model):
fields = [field.name for field in model._meta.get_fields()]
iterator = 1
for field in fields:
print("%s. %s" % (iterator, field))
iterator += 1
validated = False
while not validated:
first_field = int(
input(
"Enter the number of the (first) field you would like to de-duplicate."
)
)
if first_field in range(1, iterator):
validated = True
else:
print("Invalid input. Please try again.")
fields_to_deduplicate = [fields[first_field - 1]]
done = False
while not done:
available_fields = [f for f in fields if f not in fields_to_deduplicate]
iterator = 1
for field in available_fields:
print("%s. %s" % (iterator, field))
iterator += 1
print("C. Done adding fields.")
validated = False
while not validated:
print("You are currently deduplicating on the following fields:")
print("\n".join(fields_to_deduplicate) + "\n")
additional_field = input("""
Enter the number of the field you would like to de-duplicate.
If you have entered all fields, enter C to continue.
""")
if additional_field == "C":
done = True
validated = True
elif int(additional_field) in list(range(1, len(available_fields) + 1)):
fields_to_deduplicate += [available_fields[int(additional_field) - 1]]
validated = True
else:
print("Invalid input. Please try again.")
return fields_to_deduplicate
def keep_first_or_last_instance():
while True:
first_or_last = input("""
Do you want to keep the first or last duplicate instance?
Enter "first" or "last" to continue.
""")
if first_or_last in ["first", "last"]:
return first_or_last
def get_generic_fields():
"""Return a list of all GenericForeignKeys in all models."""
generic_fields = []
for model in apps.get_models():
for field_name, field in model.__dict__.items():
if isinstance(field, GenericForeignKey):
generic_fields.append(field)
return generic_fields
class Command(BaseCommand):
help = """
Removes duplicate model instances based on a specified
model and field name(s).
Makes sure that any OneToOne, ForeignKey, or ManyToMany relationships
attached to a deleted model(s) get reattached to the remaining model.
Based on the following:
https://djangosnippets.org/snippets/2283/
https://stackoverflow.com/a/41291137/2532070
https://gist.github.com/edelvalle/01886b6f79ba0c4dce66
"""
@signalcommand
def handle(self, *args, **options):
model = get_model_to_deduplicate()
field_names = get_field_names(model)
first_or_last = keep_first_or_last_instance()
total_deleted_objects_count = 0
for instance in model.objects.all():
kwargs = {}
for field_name in field_names:
instance_field_value = instance.__getattribute__(field_name)
kwargs.update({field_name: instance_field_value})
try:
model.objects.get(**kwargs)
except model.MultipleObjectsReturned:
instances = model.objects.filter(**kwargs)
if first_or_last == "first":
primary_object = instances.first()
alias_objects = instances.exclude(pk=primary_object.pk)
elif first_or_last == "last":
primary_object = instances.last()
alias_objects = instances.exclude(pk=primary_object.pk)
primary_object, deleted_objects, deleted_objects_count = (
self.merge_model_instances(primary_object, alias_objects)
)
total_deleted_objects_count += deleted_objects_count
print(
"Successfully deleted {} model instances.".format(
total_deleted_objects_count
)
)
@transaction.atomic()
def merge_model_instances(self, primary_object, alias_objects):
"""
Merge several model instances into one, the `primary_object`.
Use this function to merge model objects and migrate all of the related
fields from the alias objects the primary object.
"""
generic_fields = get_generic_fields()
# get related fields
related_fields = list(
filter(lambda x: x.is_relation is True, primary_object._meta.get_fields())
)
many_to_many_fields = list(
filter(lambda x: x.many_to_many is True, related_fields)
)
related_fields = list(filter(lambda x: x.many_to_many is False, related_fields))
# Loop through all alias objects and migrate their references to the
# primary object
deleted_objects = []
deleted_objects_count = 0
for alias_object in alias_objects:
# Migrate all foreign key references from alias object to primary
# object.
for many_to_many_field in many_to_many_fields:
alias_varname = many_to_many_field.name
related_objects = getattr(alias_object, alias_varname)
for obj in related_objects.all():
try:
# Handle regular M2M relationships.
getattr(alias_object, alias_varname).remove(obj)
getattr(primary_object, alias_varname).add(obj)
except AttributeError:
# Handle M2M relationships with a 'through' model.
# This does not delete the 'through model.
# TODO: Allow the user to delete a duplicate 'through' model.
through_model = getattr(alias_object, alias_varname).through
kwargs = {
many_to_many_field.m2m_reverse_field_name(): obj,
many_to_many_field.m2m_field_name(): alias_object,
}
through_model_instances = through_model.objects.filter(**kwargs)
for instance in through_model_instances:
# Re-attach the through model to the primary_object
setattr(
instance,
many_to_many_field.m2m_field_name(),
primary_object,
)
instance.save()
# TODO: Here, try to delete duplicate instances that are
# disallowed by a unique_together constraint
for related_field in related_fields:
if related_field.one_to_many:
alias_varname = related_field.get_accessor_name()
related_objects = getattr(alias_object, alias_varname)
for obj in related_objects.all():
field_name = related_field.field.name
setattr(obj, field_name, primary_object)
obj.save()
elif related_field.one_to_one or related_field.many_to_one:
alias_varname = related_field.name
related_object = getattr(alias_object, alias_varname)
primary_related_object = getattr(primary_object, alias_varname)
if primary_related_object is None:
setattr(primary_object, alias_varname, related_object)
primary_object.save()
elif related_field.one_to_one:
self.stdout.write(
"Deleted {} with id {}\n".format(
related_object, related_object.id
)
)
related_object.delete()
for field in generic_fields:
filter_kwargs = {}
filter_kwargs[field.fk_field] = alias_object._get_pk_val()
filter_kwargs[field.ct_field] = field.get_content_type(alias_object)
related_objects = field.model.objects.filter(**filter_kwargs)
for generic_related_object in related_objects:
setattr(generic_related_object, field.name, primary_object)
generic_related_object.save()
if alias_object.id:
deleted_objects += [alias_object]
self.stdout.write(
"Deleted {} with id {}\n".format(alias_object, alias_object.id)
)
alias_object.delete()
deleted_objects_count += 1
return primary_object, deleted_objects, deleted_objects_count

View File

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
import os
import re
from django.conf import settings
from django.core.management.base import BaseCommand
from django_extensions.compat import get_template_setting
from django_extensions.management.utils import signalcommand
ANNOTATION_RE = re.compile(
r"\{?#[\s]*?(TODO|FIXME|BUG|HACK|WARNING|NOTE|XXX)[\s:]?(.+)"
)
ANNOTATION_END_RE = re.compile(r"(.*)#\}(.*)")
class Command(BaseCommand):
help = "Show all annotations like TODO, FIXME, BUG, HACK, WARNING, NOTE or XXX "
"in your py and HTML files."
label = "annotation tag (TODO, FIXME, BUG, HACK, WARNING, NOTE, XXX)"
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--tag", dest="tag", help="Search for specific tags only", action="append"
)
@signalcommand
def handle(self, *args, **options):
# don't add django internal code
apps = [
app.replace(".", "/")
for app in filter(
lambda app: not app.startswith("django.contrib"),
settings.INSTALLED_APPS,
)
]
template_dirs = get_template_setting("DIRS", [])
base_dir = getattr(settings, "BASE_DIR")
if template_dirs:
apps += template_dirs
for app_dir in apps:
if base_dir:
app_dir = os.path.join(base_dir, app_dir)
for top, dirs, files in os.walk(app_dir):
for fn in files:
if os.path.splitext(fn)[1] in (".py", ".html"):
fpath = os.path.join(top, fn)
annotation_lines = []
with open(fpath, "r") as fd:
i = 0
for line in fd.readlines():
i += 1
if ANNOTATION_RE.search(line):
tag, msg = ANNOTATION_RE.findall(line)[0]
if options["tag"]:
if tag not in map(
str.upper, map(str, options["tag"])
):
break
if ANNOTATION_END_RE.search(msg.strip()):
msg = ANNOTATION_END_RE.findall(msg.strip())[0][
0
]
annotation_lines.append(
"[%3s] %-5s %s" % (i, tag, msg.strip())
)
if annotation_lines:
self.stdout.write("%s:" % fpath)
for annotation in annotation_lines:
self.stdout.write(" * %s" % annotation)
self.stdout.write("")

View File

@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
"""
print_settings
==============
Django command similar to 'diffsettings' but shows all active Django settings.
"""
import fnmatch
import json
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
help = "Print the active Django settings."
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"setting", nargs="*", help="Specifies setting to be printed."
)
parser.add_argument(
"-f",
"--fail",
action="store_true",
dest="fail",
help="Fail if invalid setting name is given.",
)
parser.add_argument(
"--format", default="simple", dest="format", help="Specifies output format."
)
parser.add_argument(
"--indent",
default=4,
dest="indent",
type=int,
help="Specifies indent level for JSON and YAML",
)
@signalcommand
def handle(self, *args, **options):
setting_names = options["setting"]
settings_dct = {k: getattr(settings, k) for k in dir(settings) if k.isupper()}
if setting_names:
settings_dct = {
key: value
for key, value in settings_dct.items()
if any(
fnmatch.fnmatchcase(key, setting_name)
for setting_name in setting_names
)
}
if options["fail"]:
for setting_name in setting_names:
if not any(
fnmatch.fnmatchcase(key, setting_name)
for key in settings_dct.keys()
):
raise CommandError("%s not found in settings." % setting_name)
output_format = options["format"]
indent = options["indent"]
if output_format == "json":
print(json.dumps(settings_dct, indent=indent))
elif output_format == "yaml":
import yaml # requires PyYAML
print(yaml.dump(settings_dct, indent=indent))
elif output_format == "pprint":
from pprint import pprint
pprint(settings_dct)
elif output_format == "text":
for key, value in settings_dct.items():
print("%s = %s" % (key, value))
elif output_format == "value":
for value in settings_dct.values():
print(value)
else:
for key, value in settings_dct.items():
print("%-40s = %r" % (key, value))

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
import importlib
from django.conf import settings
from django.contrib.auth import load_backend, BACKEND_SESSION_KEY, SESSION_KEY
from django.contrib.sessions.backends.base import VALID_KEY_CHARS
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
help = (
"print the user information for the provided session key. "
"this is very helpful when trying to track down the person who "
"experienced a site crash."
)
def add_arguments(self, parser):
parser.add_argument("session_id", nargs="+", type=str, help="user session id")
@signalcommand
def handle(self, *args, **options):
key = options["session_id"][0]
if not set(key).issubset(set(VALID_KEY_CHARS)):
raise CommandError("malformed session key")
engine = importlib.import_module(settings.SESSION_ENGINE)
if not engine.SessionStore().exists(key):
print("Session Key does not exist. Expired?")
return
session = engine.SessionStore(key)
data = session.load()
print("Session to Expire: %s" % session.get_expiry_date())
print("Raw Data: %s" % data)
uid = data.get(SESSION_KEY, None)
backend_path = data.get(BACKEND_SESSION_KEY, None)
if backend_path is None:
print("No authentication backend associated with session")
return
if uid is None:
print("No user associated with session")
return
print("User id: %s" % uid)
backend = load_backend(backend_path)
user = backend.get_user(user_id=uid)
if user is None:
print("No user associated with that id.")
return
# use django standrd api for reporting
print("full name: %s" % user.get_full_name())
print("short name: %s" % user.get_short_name())
print("username: %s" % user.get_username())
if hasattr(user, "email"):
print("email: %s" % user.email)

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from django.core.management.base import BaseCommand
from django_extensions.management.utils import signalcommand
class DjangoExtensionsTestException(Exception):
pass
class Command(BaseCommand):
help = (
"Raises a test Exception named DjangoExtensionsTestException. "
"Useful for debugging integration with error reporters like Sentry."
)
@signalcommand
def handle(self, *args, **options):
message = (
"This is a test exception via the "
"django-extensions raise_test_exception management command."
)
raise DjangoExtensionsTestException(message)

View File

@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
"""
reset_db command
originally from https://www.djangosnippets.org/snippets/828/ by dnordberg
"""
import importlib.util
import os
import logging
import warnings
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS
from django_extensions.settings import SQLITE_ENGINES, POSTGRESQL_ENGINES, MYSQL_ENGINES
from django_extensions.management.mysql import parse_mysql_cnf
from django_extensions.management.utils import signalcommand
from django_extensions.utils.deprecation import RemovedInNextVersionWarning
class Command(BaseCommand):
help = "Resets the database for this project."
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--noinput",
"--no-input",
action="store_false",
dest="interactive",
default=True,
help="Tells Django to NOT prompt the user for input of any kind.",
)
parser.add_argument(
"--no-utf8",
action="store_true",
dest="no_utf8_support",
default=False,
help="Tells Django to not create a UTF-8 charset database",
)
parser.add_argument(
"-U",
"--user",
action="store",
dest="user",
default=None,
help="Use another user for the database than defined in settings.py",
)
parser.add_argument(
"-O",
"--owner",
action="store",
dest="owner",
default=None,
help="Use another owner for creating the database than the user defined "
"in settings or via --user",
)
parser.add_argument(
"-P",
"--password",
action="store",
dest="password",
default=None,
help="Use another password for the database than defined in settings.py",
)
parser.add_argument(
"-D",
"--dbname",
action="store",
dest="dbname",
default=None,
help="Use another database name than defined in settings.py",
)
parser.add_argument(
"-R",
"--router",
action="store",
dest="router",
default=DEFAULT_DB_ALIAS,
help="Use this router-database other than defined in settings.py",
)
parser.add_argument(
"--database",
default=DEFAULT_DB_ALIAS,
help='Nominates a database to run command for. Defaults to the "%s".'
% DEFAULT_DB_ALIAS,
)
parser.add_argument(
"-c",
"--close-sessions",
action="store_true",
dest="close_sessions",
default=False,
help="Close database connections before dropping database "
"(currently works on PostgreSQL only)",
)
@signalcommand
def handle(self, *args, **options):
"""
Reset the database for this project.
Note: Transaction wrappers are in reverse as a work around for
autocommit, anybody know how to do this the right way?
"""
database = options["database"]
if options["router"] != DEFAULT_DB_ALIAS:
warnings.warn(
"--router is deprecated. You should use --database.",
RemovedInNextVersionWarning,
stacklevel=2,
)
database = options["router"]
dbinfo = settings.DATABASES.get(database)
if dbinfo is None:
raise CommandError("Unknown database %s" % database)
engine = dbinfo.get("ENGINE")
user = password = database_name = database_host = database_port = ""
if engine == "mysql":
(user, password, database_name, database_host, database_port) = (
parse_mysql_cnf(dbinfo)
)
user = options["user"] or dbinfo.get("USER") or user
password = options["password"] or dbinfo.get("PASSWORD") or password
owner = options["owner"] or user
database_name = options["dbname"] or dbinfo.get("NAME") or database_name
if database_name == "":
raise CommandError(
"You need to specify DATABASE_NAME in your Django settings file."
)
database_host = dbinfo.get("HOST") or database_host
database_port = dbinfo.get("PORT") or database_port
verbosity = options["verbosity"]
if options["interactive"]:
confirm = input(
"""
You have requested a database reset.
This will IRREVERSIBLY DESTROY
ALL data in the database "%s".
Are you sure you want to do this?
Type 'yes' to continue, or 'no' to cancel: """
% (database_name,)
)
else:
confirm = "yes"
if confirm != "yes":
print("Reset cancelled.")
return
if engine in SQLITE_ENGINES:
try:
logging.info("Unlinking %s database", engine)
os.unlink(database_name)
except OSError:
pass
elif engine in MYSQL_ENGINES:
import MySQLdb as Database
kwargs = {
"user": user,
"passwd": password,
}
if database_host.startswith("/"):
kwargs["unix_socket"] = database_host
else:
kwargs["host"] = database_host
if database_port:
kwargs["port"] = int(database_port)
connection = Database.connect(**kwargs)
drop_query = "DROP DATABASE IF EXISTS `%s`" % database_name
utf8_support = "" if options["no_utf8_support"] else "CHARACTER SET utf8"
create_query = "CREATE DATABASE `%s` %s" % (database_name, utf8_support)
logging.info('Executing... "%s"', drop_query)
connection.query(drop_query)
logging.info('Executing... "%s"', create_query)
connection.query(create_query.strip())
elif engine in POSTGRESQL_ENGINES:
has_psycopg3 = importlib.util.find_spec("psycopg")
if has_psycopg3:
import psycopg as Database # NOQA
else:
import psycopg2 as Database # NOQA
conn_params = {"dbname": "template1"}
if user:
conn_params["user"] = user
if password:
conn_params["password"] = password
if database_host:
conn_params["host"] = database_host
if database_port:
conn_params["port"] = database_port
connection = Database.connect(**conn_params)
if has_psycopg3:
connection.autocommit = True
else:
connection.set_isolation_level(0) # autocommit false
cursor = connection.cursor()
if options["close_sessions"]:
close_sessions_query = (
"""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '%s';
"""
% database_name
)
logging.info('Executing... "%s"', close_sessions_query.strip())
try:
cursor.execute(close_sessions_query)
except Database.ProgrammingError as e:
logging.exception("Error: %s", str(e))
drop_query = 'DROP DATABASE "%s";' % database_name
logging.info('Executing... "%s"', drop_query)
try:
cursor.execute(drop_query)
except Database.ProgrammingError as e:
logging.exception("Error: %s", str(e))
create_query = 'CREATE DATABASE "%s"' % database_name
if owner:
create_query += ' WITH OWNER = "%s" ' % owner
create_query += " ENCODING = 'UTF8'"
if settings.DEFAULT_TABLESPACE:
create_query += " TABLESPACE = %s;" % settings.DEFAULT_TABLESPACE
else:
create_query += ";"
logging.info('Executing... "%s"', create_query)
cursor.execute(create_query)
else:
raise CommandError("Unknown database engine %s" % engine)
if verbosity >= 2 or options["interactive"]:
print("Reset successful.")

View File

@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
"""
Recreates the public schema for current database (PostgreSQL only).
Useful for Docker environments where you need to reset database
schema while there are active connections.
"""
import warnings
from django.core.management import BaseCommand, CommandError
from django.db import DEFAULT_DB_ALIAS
from django.db import connections
from django.conf import settings
from django_extensions.settings import POSTGRESQL_ENGINES
from django_extensions.utils.deprecation import RemovedInNextVersionWarning
class Command(BaseCommand):
"""`reset_schema` command implementation."""
help = "Recreates the public schema for this project."
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--noinput",
"--no-input",
action="store_false",
dest="interactive",
default=True,
help="Tells Django to NOT prompt the user for input of any kind.",
)
parser.add_argument(
"-R",
"--router",
action="store",
dest="router",
default=DEFAULT_DB_ALIAS,
help="Use this router-database instead of the one defined in settings.py",
)
parser.add_argument(
"--database",
default=DEFAULT_DB_ALIAS,
help='Nominates a database to run command for. Defaults to the "%s".'
% DEFAULT_DB_ALIAS,
)
parser.add_argument(
"-S",
"--schema",
action="store",
dest="schema",
default="public",
help='Drop this schema instead of "public"',
)
def handle(self, *args, **options):
database = options["database"]
if options["router"] != DEFAULT_DB_ALIAS:
warnings.warn(
"--router is deprecated. You should use --database.",
RemovedInNextVersionWarning,
stacklevel=2,
)
database = options["router"]
dbinfo = settings.DATABASES.get(database)
if dbinfo is None:
raise CommandError("Unknown database %s" % database)
engine = dbinfo.get("ENGINE")
if engine not in POSTGRESQL_ENGINES:
raise CommandError(
"This command can be used only with PostgreSQL databases."
)
database_name = dbinfo["NAME"]
schema = options["schema"]
if options["interactive"]:
confirm = input(
"""
You have requested a database schema reset.
This will IRREVERSIBLY DESTROY ALL data
in the "{}" schema of database "{}".
Are you sure you want to do this?
Type 'yes' to continue, or 'no' to cancel: """.format(schema, database_name)
)
else:
confirm = "yes"
if confirm != "yes":
print("Reset cancelled.")
return
with connections[database].cursor() as cursor:
cursor.execute("DROP SCHEMA {} CASCADE".format(schema))
cursor.execute("CREATE SCHEMA {}".format(schema))

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
import logging
from django.core.management.base import BaseCommand
from django_extensions.management.jobs import get_job, print_jobs
from django_extensions.management.utils import setup_logger, signalcommand
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = "Run a single maintenance job."
missing_args_message = "test"
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument("app_name", nargs="?")
parser.add_argument("job_name", nargs="?")
parser.add_argument(
"--list",
"-l",
action="store_true",
dest="list_jobs",
default=False,
help="List all jobs with their description",
)
def runjob(self, app_name, job_name, options):
verbosity = options["verbosity"]
if verbosity > 1:
logger.info("Executing job: %s (app: %s)", job_name, app_name)
try:
job = get_job(app_name, job_name)
except KeyError:
if app_name:
logger.error(
"Error: Job %s for applabel %s not found", job_name, app_name
)
else:
logger.error("Error: Job %s not found", job_name)
logger.info("Use -l option to view all the available jobs")
return
try:
job().execute()
except Exception:
logger.exception("ERROR OCCURED IN JOB: %s (APP: %s)", job_name, app_name)
@signalcommand
def handle(self, *args, **options):
app_name = options["app_name"]
job_name = options["job_name"]
# hack since we are using job_name nargs='?' for -l to work
if app_name and not job_name:
job_name = app_name
app_name = None
setup_logger(logger, self.stdout)
if options["list_jobs"]:
print_jobs(only_scheduled=False, show_when=True, show_appname=True)
else:
self.runjob(app_name, job_name, options)

View File

@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
import logging
from django.apps import apps
from django.core.management.base import BaseCommand
from django_extensions.management.jobs import get_jobs, print_jobs
from django_extensions.management.utils import setup_logger, signalcommand
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = "Runs scheduled maintenance jobs."
when_options = [
"minutely",
"quarter_hourly",
"hourly",
"daily",
"weekly",
"monthly",
"yearly",
]
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"when", nargs="?", help="options: %s" % ", ".join(self.when_options)
)
parser.add_argument(
"--list",
"-l",
action="store_true",
dest="list_jobs",
default=False,
help="List all jobs with their description",
)
def usage_msg(self):
print("%s Please specify: %s" % (self.help, ", ".join(self.when_options)))
def runjobs(self, when, options):
verbosity = options["verbosity"]
jobs = get_jobs(when, only_scheduled=True)
for app_name, job_name in sorted(jobs.keys()):
job = jobs[(app_name, job_name)]
if verbosity > 1:
logger.info("Executing %s job: %s (app: %s)", when, job_name, app_name)
try:
job().execute()
except Exception:
logger.exception(
"ERROR OCCURED IN JOB: %s (APP: %s)", job_name, app_name
)
def runjobs_by_signals(self, when, options):
"""Run jobs from the signals"""
# Thanks for Ian Holsman for the idea and code
from django_extensions.management import signals
from django.conf import settings
verbosity = options["verbosity"]
for app_name in settings.INSTALLED_APPS:
try:
__import__(app_name + ".management", "", "", [""])
except ImportError:
pass
for app in (
app.models_module for app in apps.get_app_configs() if app.models_module
):
if verbosity > 1:
app_name = ".".join(app.__name__.rsplit(".")[:-1])
print("Sending %s job signal for: %s" % (when, app_name))
if when == "minutely":
signals.run_minutely_jobs.send(sender=app, app=app)
elif when == "quarter_hourly":
signals.run_quarter_hourly_jobs.send(sender=app, app=app)
elif when == "hourly":
signals.run_hourly_jobs.send(sender=app, app=app)
elif when == "daily":
signals.run_daily_jobs.send(sender=app, app=app)
elif when == "weekly":
signals.run_weekly_jobs.send(sender=app, app=app)
elif when == "monthly":
signals.run_monthly_jobs.send(sender=app, app=app)
elif when == "yearly":
signals.run_yearly_jobs.send(sender=app, app=app)
@signalcommand
def handle(self, *args, **options):
when = options["when"]
setup_logger(logger, self.stdout)
if options["list_jobs"]:
print_jobs(when, only_scheduled=True, show_when=True, show_appname=True)
elif when in self.when_options:
self.runjobs(when, options)
self.runjobs_by_signals(when, options)
else:
self.usage_msg()

View File

@@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
"""
runprofileserver.py
Starts a lightweight Web server with profiling enabled.
Credits for kcachegrind support taken from lsprofcalltree.py go to:
David Allouche
Jp Calderone & Itamar Shtull-Trauring
Johan Dahlin
"""
import sys
from datetime import datetime
from django.conf import settings
from django.contrib.staticfiles.handlers import StaticFilesHandler
from django.core.management.base import BaseCommand, CommandError
from django.core.servers.basehttp import get_internal_wsgi_application
from django_extensions.management.utils import signalcommand
USE_STATICFILES = "django.contrib.staticfiles" in settings.INSTALLED_APPS
class KCacheGrind:
def __init__(self, profiler):
self.data = profiler.getstats()
self.out_file = None
def output(self, out_file):
self.out_file = out_file
self.out_file.write("events: Ticks\n")
self._print_summary()
for entry in self.data:
self._entry(entry)
def _print_summary(self):
max_cost = 0
for entry in self.data:
totaltime = int(entry.totaltime * 1000)
max_cost = max(max_cost, totaltime)
self.out_file.write("summary: %d\n" % (max_cost,))
def _entry(self, entry):
out_file = self.out_file
code = entry.code
if isinstance(code, str):
out_file.write("fn=%s\n" % code)
else:
out_file.write("fl=%s\n" % code.co_filename)
out_file.write("fn=%s\n" % code.co_name)
inlinetime = int(entry.inlinetime * 1000)
if isinstance(code, str):
out_file.write("0 %s\n" % inlinetime)
else:
out_file.write("%d %d\n" % (code.co_firstlineno, inlinetime))
# recursive calls are counted in entry.calls
if entry.calls:
calls = entry.calls
else:
calls = []
if isinstance(code, str):
lineno = 0
else:
lineno = code.co_firstlineno
for subentry in calls:
self._subentry(lineno, subentry)
out_file.write("\n")
def _subentry(self, lineno, subentry):
out_file = self.out_file
code = subentry.code
if isinstance(code, str):
out_file.write("cfn=%s\n" % code)
out_file.write("calls=%d 0\n" % (subentry.callcount,))
else:
out_file.write("cfl=%s\n" % code.co_filename)
out_file.write("cfn=%s\n" % code.co_name)
out_file.write("calls=%d %d\n" % (subentry.callcount, code.co_firstlineno))
totaltime = int(subentry.totaltime * 1000)
out_file.write("%d %d\n" % (lineno, totaltime))
class Command(BaseCommand):
help = "Starts a lightweight Web server with profiling enabled."
args = "[optional port number, or ipaddr:port]"
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"addrport", nargs="?", help="Optional port number, or ipaddr:port"
)
parser.add_argument(
"--noreload",
action="store_false",
dest="use_reloader",
default=True,
help="Tells Django to NOT use the auto-reloader.",
)
parser.add_argument(
"--nothreading",
action="store_false",
dest="use_threading",
default=True,
help="Tells Django to NOT use threading.",
)
parser.add_argument(
"--prof-path",
dest="prof_path",
default="/tmp",
help="Specifies the directory which to save profile information in.",
)
parser.add_argument(
"--prof-file",
dest="prof_file",
default="{path}.{duration:06d}ms.{time}",
help='Set filename format, default if "{path}.{duration:06d}ms.{time}".',
)
parser.add_argument(
"--nomedia",
action="store_true",
dest="no_media",
default=False,
help="Do not profile MEDIA_URL",
)
parser.add_argument(
"--kcachegrind",
action="store_true",
dest="use_lsprof",
default=False,
help="Create kcachegrind compatible lsprof files, this requires "
"and automatically enables cProfile.",
)
if USE_STATICFILES:
parser.add_argument(
"--nostatic",
action="store_false",
dest="use_static_handler",
default=True,
help="Tells Django to NOT automatically serve static files "
"at STATIC_URL.",
)
parser.add_argument(
"--insecure",
action="store_true",
dest="insecure_serving",
default=False,
help="Allows serving static files even if DEBUG is False.",
)
@signalcommand
def handle(self, addrport="", *args, **options):
import django
import socket
import errno
from django.core.servers.basehttp import run
if not addrport:
addr = ""
port = "8000"
else:
try:
addr, port = addrport.split(":")
except ValueError:
addr, port = "", addrport
if not addr:
addr = "127.0.0.1"
if not port.isdigit():
raise CommandError("%r is not a valid port number." % port)
use_reloader = options["use_reloader"]
shutdown_message = options.get("shutdown_message", "")
no_media = options["no_media"]
quit_command = (sys.platform == "win32") and "CTRL-BREAK" or "CONTROL-C"
def inner_run():
import os
import time
import cProfile
USE_LSPROF = options["use_lsprof"]
prof_path = options["prof_path"]
prof_file = options["prof_file"]
if not prof_file.format(path="1", duration=2, time=3):
prof_file = "{path}.{duration:06d}ms.{time}"
print(
"Filename format is wrong. "
"Default format used: '{path}.{duration:06d}ms.{time}'."
)
def get_exclude_paths():
exclude_paths = []
media_url = getattr(settings, "MEDIA_URL", None)
if media_url:
exclude_paths.append(media_url)
static_url = getattr(settings, "STATIC_URL", None)
if static_url:
exclude_paths.append(static_url)
return exclude_paths
def make_profiler_handler(inner_handler):
def handler(environ, start_response):
path_info = environ["PATH_INFO"]
# when using something like a dynamic site middleware is could be
# necessary to refetch the exclude_paths every time since they could
# change per site.
if no_media and any(
path_info.startswith(p) for p in get_exclude_paths()
):
return inner_handler(environ, start_response)
path_name = path_info.strip("/").replace("/", ".") or "root"
profname = "%s.%d.prof" % (path_name, time.time())
profname = os.path.join(prof_path, profname)
prof = cProfile.Profile()
start = datetime.now()
try:
return prof.runcall(inner_handler, environ, start_response)
finally:
# seeing how long the request took is important!
elap = datetime.now() - start
elapms = elap.seconds * 1000.0 + elap.microseconds / 1000.0
if USE_LSPROF:
kg = KCacheGrind(prof)
with open(profname, "w") as f:
kg.output(f)
else:
prof.dump_stats(profname)
profname2 = prof_file.format(
path=path_name, duration=int(elapms), time=int(time.time())
)
profname2 = os.path.join(prof_path, "%s.prof" % profname2)
os.rename(profname, profname2)
return handler
print("Performing system checks...")
self.check(display_num_errors=True)
print(
"\nDjango version %s, using settings %r"
% (django.get_version(), settings.SETTINGS_MODULE)
)
print("Development server is running at http://%s:%s/" % (addr, port))
print("Quit the server with %s." % quit_command)
try:
handler = get_internal_wsgi_application()
if USE_STATICFILES:
use_static_handler = options["use_static_handler"]
insecure_serving = options["insecure_serving"]
if use_static_handler and (settings.DEBUG or insecure_serving):
handler = StaticFilesHandler(handler)
handler = make_profiler_handler(handler)
run(addr, int(port), handler, threading=options["use_threading"])
except socket.error as e:
# Use helpful error messages instead of ugly tracebacks.
ERRORS = {
errno.EACCES: "You don't have permission to access that port.",
errno.EADDRINUSE: "That port is already in use.",
errno.EADDRNOTAVAIL: "That IP address can't be assigned-to.",
}
try:
error_text = ERRORS[e.errno]
except (AttributeError, KeyError):
error_text = str(e)
sys.stderr.write(self.style.ERROR("Error: %s" % error_text) + "\n")
# Need to use an OS exit because sys.exit doesn't work in a thread
os._exit(1)
except KeyboardInterrupt:
if shutdown_message:
print(shutdown_message)
sys.exit(0)
if use_reloader:
try:
from django.utils.autoreload import run_with_reloader
run_with_reloader(inner_run)
except ImportError:
from django.utils import autoreload
autoreload.main(inner_run)
else:
inner_run()

View File

@@ -0,0 +1,356 @@
# -*- coding: utf-8 -*-
import os
import sys
import importlib
import inspect
import traceback
from argparse import ArgumentTypeError
from django.apps import apps
from django.conf import settings
from django.core.management.base import CommandError
from django_extensions.management.email_notifications import EmailNotificationCommand
from django_extensions.management.utils import signalcommand
class DirPolicyChoices:
NONE = "none"
EACH = "each"
ROOT = "root"
def check_is_directory(value):
if value is None or not os.path.isdir(value):
raise ArgumentTypeError("%s is not a directory!" % value)
return value
class BadCustomDirectoryException(Exception):
def __init__(self, value):
self.message = (
value + " If --dir-policy is custom than you must set correct directory in "
"--dir option or in settings.RUNSCRIPT_CHDIR"
)
def __str__(self):
return self.message
class Command(EmailNotificationCommand):
help = "Runs a script in django context."
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.current_directory = os.getcwd()
self.last_exit_code = 0
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument("script", nargs="+")
parser.add_argument(
"--fixtures",
action="store_true",
dest="infixtures",
default=False,
help="Also look in app.fixtures subdir",
)
parser.add_argument(
"--noscripts",
action="store_true",
dest="noscripts",
default=False,
help="Do not look in app.scripts subdir",
)
parser.add_argument(
"-s",
"--silent",
action="store_true",
dest="silent",
default=False,
help="Run silently, do not show errors and tracebacks."
" Also implies --continue-on-error.",
)
parser.add_argument(
"-c",
"--continue-on-error",
action="store_true",
dest="continue_on_error",
default=False,
help="Continue executing other scripts even though one has failed. "
"It will print a traceback unless --no-traceback or --silent are given "
"The exit code used when terminating will always be 1.",
)
parser.add_argument(
"--no-traceback",
action="store_true",
dest="no_traceback",
default=False,
help="Do not show tracebacks",
)
parser.add_argument(
"--script-args",
nargs="*",
type=str,
help="Space-separated argument list to be passed to the scripts. Note that "
"the same arguments will be passed to all named scripts.",
)
parser.add_argument(
"--dir-policy",
type=str,
choices=[
DirPolicyChoices.NONE,
DirPolicyChoices.EACH,
DirPolicyChoices.ROOT,
],
help="Policy of selecting scripts execution directory: "
"none - start all scripts in current directory "
"each - start all scripts in their directories "
"root - start all scripts in BASE_DIR directory ",
)
parser.add_argument(
"--chdir",
type=check_is_directory,
help="If dir-policy option is set to custom, than this option determines "
"script execution directory.",
)
@signalcommand
def handle(self, *args, **options):
NOTICE = self.style.SQL_TABLE
NOTICE2 = self.style.SQL_FIELD
ERROR = self.style.ERROR
ERROR2 = self.style.NOTICE
subdirs = []
scripts = options["script"]
if not options["noscripts"]:
subdirs.append(getattr(settings, "RUNSCRIPT_SCRIPT_DIR", "scripts"))
if options["infixtures"]:
subdirs.append("fixtures")
verbosity = options["verbosity"]
show_traceback = options["traceback"]
no_traceback = options["no_traceback"]
continue_on_error = options["continue_on_error"]
if no_traceback:
show_traceback = False
else:
show_traceback = True
silent = options["silent"]
if silent:
verbosity = 0
continue_on_error = True
email_notifications = options["email_notifications"]
if len(subdirs) < 1:
print(NOTICE("No subdirs to run left."))
return
if len(scripts) < 1:
print(ERROR("Script name required."))
return
def get_directory_from_chdir():
directory = options["chdir"] or getattr(settings, "RUNSCRIPT_CHDIR", None)
try:
check_is_directory(directory)
except ArgumentTypeError as e:
raise BadCustomDirectoryException(str(e))
return directory
def get_directory_basing_on_policy(script_module):
policy = options["dir_policy"] or getattr(
settings, "RUNSCRIPT_CHDIR_POLICY", DirPolicyChoices.NONE
)
if policy == DirPolicyChoices.ROOT:
return settings.BASE_DIR
elif policy == DirPolicyChoices.EACH:
return os.path.dirname(inspect.getfile(script_module))
else:
return self.current_directory
def set_directory(script_module):
if options["chdir"]:
directory = get_directory_from_chdir()
elif options["dir_policy"]:
directory = get_directory_basing_on_policy(script_module)
elif getattr(settings, "RUNSCRIPT_CHDIR", None):
directory = get_directory_from_chdir()
else:
directory = get_directory_basing_on_policy(script_module)
os.chdir(os.path.abspath(directory))
def run_script(mod, *script_args):
exit_code = None
try:
set_directory(mod)
exit_code = mod.run(*script_args)
if isinstance(exit_code, bool):
# convert boolean True to exit-code 0 and False to exit-code 1
exit_code = 1 if exit_code else 0
if isinstance(exit_code, int):
if exit_code != 0:
try:
raise CommandError(
"'%s' failed with exit code %s"
% (mod.__name__, exit_code),
returncode=exit_code,
)
except TypeError:
raise CommandError(
"'%s' failed with exit code %s"
% (mod.__name__, exit_code)
)
if email_notifications:
self.send_email_notification(notification_id=mod.__name__)
except Exception as e:
if isinstance(e, CommandError) and hasattr(e, "returncode"):
exit_code = e.returncode
self.last_exit_code = exit_code if isinstance(exit_code, int) else 1
if silent:
return
if verbosity > 0:
print(ERROR("Exception while running run() in '%s'" % mod.__name__))
if continue_on_error:
if show_traceback:
traceback.print_exc()
return
if email_notifications:
self.send_email_notification(
notification_id=mod.__name__, include_traceback=True
)
if no_traceback:
raise CommandError(repr(e))
raise
def my_import(parent_package, module_name):
full_module_path = "%s.%s" % (parent_package, module_name)
if verbosity > 1:
print(NOTICE("Check for %s" % full_module_path))
# Try importing the parent package first
try:
importlib.import_module(parent_package)
except ImportError as e:
if str(e).startswith("No module named"):
# No need to proceed if the parent package doesn't exist
return False
try:
t = importlib.import_module(full_module_path)
except ImportError as e:
# The parent package exists, but the module doesn't
try:
if importlib.util.find_spec(full_module_path) is None:
return False
except Exception:
module_file = (
os.path.join(settings.BASE_DIR, *full_module_path.split("."))
+ ".py"
)
if not os.path.isfile(module_file):
return False
if silent:
return False
if show_traceback:
traceback.print_exc()
if verbosity > 0:
print(
ERROR("Cannot import module '%s': %s." % (full_module_path, e))
)
return False
if hasattr(t, "run"):
if verbosity > 1:
print(NOTICE2("Found script '%s' ..." % full_module_path))
return t
else:
if verbosity > 1:
print(
ERROR2(
"Found script '%s' but no run() function found."
% full_module_path
)
)
def find_modules_for_script(script):
"""Find script module which contains 'run' attribute"""
modules = []
# first look in apps
for app in apps.get_app_configs():
for subdir in subdirs:
mod = my_import("%s.%s" % (app.name, subdir), script)
if mod:
modules.append(mod)
# try direct import
if script.find(".") != -1:
parent, mod_name = script.rsplit(".", 1)
mod = my_import(parent, mod_name)
if mod:
modules.append(mod)
else:
# try app.DIR.script import
for subdir in subdirs:
mod = my_import(subdir, script)
if mod:
modules.append(mod)
return modules
if options["script_args"]:
script_args = options["script_args"]
else:
script_args = []
# first pass to check if all scripts can be found
script_to_run = []
for script in scripts:
script_modules = find_modules_for_script(script)
if not script_modules:
self.last_exit_code = 1
if verbosity > 0 and not silent:
print(ERROR("No (valid) module for script '%s' found" % script))
continue
script_to_run.extend(script_modules)
if self.last_exit_code:
if verbosity < 2 and not silent:
print(
ERROR("Try running with a higher verbosity level like: -v2 or -v3")
)
if not continue_on_error:
script_to_run = []
for script_mod in script_to_run:
if verbosity > 1:
print(NOTICE2("Running script '%s' ..." % script_mod.__name__))
run_script(script_mod, *script_args)
if self.last_exit_code != 0:
if silent:
if hasattr(self, "running_tests"):
return
sys.exit(self.last_exit_code)
try:
raise CommandError(
"An error has occurred running scripts. See errors above.",
returncode=self.last_exit_code,
)
except TypeError:
# Django < 3.1 fallback
if self.last_exit_code == 1:
# if exit_code is 1 we can still raise CommandError without
# returncode argument
raise CommandError(
"An error has occurred running scripts. See errors above."
)
print(ERROR("An error has occurred running scripts. See errors above."))
if hasattr(self, "running_tests"):
return
sys.exit(self.last_exit_code)

View File

@@ -0,0 +1,756 @@
# -*- coding: utf-8 -*-
import logging
import os
import re
import socket
import sys
import traceback
import webbrowser
import functools
from pathlib import Path
from typing import List, Set # NOQA
import django
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError, SystemCheckError
from django.core.management.color import color_style
from django.core.servers.basehttp import get_internal_wsgi_application
from django.dispatch import Signal
from django.template.autoreload import get_template_directories, reset_loaders
from django.utils.autoreload import file_changed, get_reloader
from django.views import debug as django_views_debug
try:
if "whitenoise.runserver_nostatic" in settings.INSTALLED_APPS:
USE_STATICFILES = False
else:
from django.contrib.staticfiles.handlers import StaticFilesHandler
USE_STATICFILES = True
except ImportError:
USE_STATICFILES = False
try:
from werkzeug import run_simple
from werkzeug.debug import DebuggedApplication
from werkzeug.serving import WSGIRequestHandler as _WSGIRequestHandler
from werkzeug.serving import make_ssl_devcert
from werkzeug._internal import _log # type: ignore
from werkzeug import _reloader
HAS_WERKZEUG = True
except ImportError:
HAS_WERKZEUG = False
try:
import OpenSSL # NOQA
HAS_OPENSSL = True
except ImportError:
HAS_OPENSSL = False
from django_extensions.management.technical_response import null_technical_500_response
from django_extensions.management.utils import (
RedirectHandler,
has_ipdb,
setup_logger,
signalcommand,
)
from django_extensions.management.debug_cursor import monkey_patch_cursordebugwrapper
runserver_plus_started = Signal()
naiveip_re = re.compile(
r"""^(?:
(?P<addr>
(?P<ipv4>\d{1,3}(?:\.\d{1,3}){3}) | # IPv4 address
(?P<ipv6>\[[a-fA-F0-9:]+\]) | # IPv6 address
(?P<fqdn>[a-zA-Z0-9-]+(?:\.[a-zA-Z0-9-]+)*) # FQDN
):)?(?P<port>\d+)$""",
re.X,
)
# 7-bit C1 ANSI sequences (https://stackoverflow.com/questions/14693701/how-can-i-remove-the-ansi-escape-sequences-from-a-string-in-python)
ansi_escape = re.compile(
r"""
\x1B # ESC
(?: # 7-bit C1 Fe (except CSI)
[@-Z\\-_]
| # or [ for CSI, followed by a control sequence
\[
[0-?]* # Parameter bytes
[ -/]* # Intermediate bytes
[@-~] # Final byte
)
""",
re.VERBOSE,
)
DEFAULT_PORT = "8000"
DEFAULT_POLLER_RELOADER_INTERVAL = getattr(
settings, "RUNSERVERPLUS_POLLER_RELOADER_INTERVAL", 1
)
DEFAULT_POLLER_RELOADER_TYPE = getattr(
settings, "RUNSERVERPLUS_POLLER_RELOADER_TYPE", "auto"
)
logger = logging.getLogger(__name__)
_error_files = set() # type: Set[str]
def get_all_template_files() -> Set[str]:
template_list = set()
for template_dir in get_template_directories():
for base_dir, _, filenames in os.walk(template_dir):
for filename in filenames:
template_list.add(os.path.join(base_dir, filename))
return template_list
if HAS_WERKZEUG:
# Monkey patch the reloader to support adding more files to extra_files
for name, reloader_loop_klass in _reloader.reloader_loops.items():
class WrappedReloaderLoop(reloader_loop_klass): # type: ignore
def __init__(self, *args, **kwargs):
self._template_files: Set[str] = get_all_template_files()
super().__init__(*args, **kwargs)
self._extra_files = self.extra_files
@property
def extra_files(self):
template_files = get_all_template_files()
# reset loaders if there are new files detected
if len(self._template_files) != len(template_files):
changed = template_files.difference(self._template_files)
for filename in changed:
_log(
"info",
f" * New file {filename} added, reset template loaders",
)
self.register_file_changed(filename)
reset_loaders()
self._template_files = template_files
return self._extra_files.union(_error_files, template_files)
@extra_files.setter
def extra_files(self, extra_files):
self._extra_files = extra_files
def trigger_reload(self, filename: str) -> None:
path = Path(filename)
results = file_changed.send(sender=self, file_path=path)
if not any(res[1] for res in results):
super().trigger_reload(filename)
else:
_log(
"info",
f" * Detected change in {filename!r}, reset template loaders",
)
self.register_file_changed(filename)
def register_file_changed(self, filename):
if hasattr(self, "mtimes"):
mtime = os.stat(filename).st_mtime
self.mtimes[filename] = mtime
_reloader.reloader_loops[name] = WrappedReloaderLoop
def gen_filenames():
return get_reloader().watched_files()
def check_errors(fn):
# Inspired by https://github.com/django/django/blob/master/django/utils/autoreload.py
@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except Exception:
_exception = sys.exc_info()
_, ev, tb = _exception
if getattr(ev, "filename", None) is None:
# get the filename from the last item in the stack
filename = traceback.extract_tb(tb)[-1][0]
else:
filename = ev.filename
if filename not in _error_files:
_error_files.add(filename)
raise
return wrapper
class Command(BaseCommand):
help = "Starts a lightweight Web server for development."
# Validation is called explicitly each time the server is reloaded.
requires_system_checks: List[str] = []
DEFAULT_CRT_EXTENSION = ".crt"
DEFAULT_KEY_EXTENSION = ".key"
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"addrport", nargs="?", help="Optional port number, or ipaddr:port"
)
parser.add_argument(
"--ipv6",
"-6",
action="store_true",
dest="use_ipv6",
default=False,
help="Tells Django to use a IPv6 address.",
)
parser.add_argument(
"--noreload",
action="store_false",
dest="use_reloader",
default=True,
help="Tells Django to NOT use the auto-reloader.",
)
parser.add_argument(
"--browser",
action="store_true",
dest="open_browser",
help="Tells Django to open a browser.",
)
parser.add_argument(
"--nothreading",
action="store_false",
dest="threaded",
help="Do not run in multithreaded mode.",
)
parser.add_argument(
"--threaded",
action="store_true",
dest="threaded",
help="Run in multithreaded mode.",
)
parser.add_argument(
"--output",
dest="output_file",
default=None,
help="Specifies an output file to send a copy of all messages "
"(not flushed immediately).",
)
parser.add_argument(
"--print-sql",
action="store_true",
default=False,
help="Print SQL queries as they're executed",
)
parser.add_argument(
"--truncate-sql",
action="store",
type=int,
help="Truncate SQL queries to a number of characters.",
)
parser.add_argument(
"--print-sql-location",
action="store_true",
default=False,
help="Show location in code where SQL query generated from",
)
cert_group = parser.add_mutually_exclusive_group()
cert_group.add_argument(
"--cert",
dest="cert_path",
action="store",
type=str,
help="Deprecated alias for --cert-file option.",
)
cert_group.add_argument(
"--cert-file",
dest="cert_path",
action="store",
type=str,
help="SSL .crt file path. If not provided path from --key-file will be "
"selected. Either --cert-file or --key-file must be provided to use SSL.",
)
parser.add_argument(
"--key-file",
dest="key_file_path",
action="store",
type=str,
help="SSL .key file path. If not provided path from --cert-file "
"will be selected. Either --cert-file or --key-file must be provided "
"to use SSL.",
)
parser.add_argument(
"--extra-file",
dest="extra_files",
action="append",
type=str,
default=[],
help="auto-reload whenever the given file changes too"
" (can be specified multiple times)",
)
parser.add_argument(
"--exclude-pattern",
dest="exclude_patterns",
action="append",
type=str,
default=[],
help="ignore reload on changes to files matching this pattern"
" (can be specified multiple times)",
)
parser.add_argument(
"--reloader-interval",
dest="reloader_interval",
action="store",
type=int,
default=DEFAULT_POLLER_RELOADER_INTERVAL,
help="After how many seconds auto-reload should scan for updates"
" in poller-mode [default=%s]" % DEFAULT_POLLER_RELOADER_INTERVAL,
)
parser.add_argument(
"--reloader-type",
dest="reloader_type",
action="store",
type=str,
default=DEFAULT_POLLER_RELOADER_TYPE,
help="Werkzeug reloader type "
"[options are auto, watchdog, or stat, default=%s]"
% DEFAULT_POLLER_RELOADER_TYPE,
)
parser.add_argument(
"--pdb",
action="store_true",
dest="pdb",
default=False,
help="Drop into pdb shell at the start of any view.",
)
parser.add_argument(
"--ipdb",
action="store_true",
dest="ipdb",
default=False,
help="Drop into ipdb shell at the start of any view.",
)
parser.add_argument(
"--pm",
action="store_true",
dest="pm",
default=False,
help="Drop into (i)pdb shell if an exception is raised in a view.",
)
parser.add_argument(
"--startup-messages",
dest="startup_messages",
action="store",
default="reload",
help="When to show startup messages: "
"reload [default], once, always, never.",
)
parser.add_argument(
"--keep-meta-shutdown",
dest="keep_meta_shutdown_func",
action="store_true",
default=False,
help="Keep request.META['werkzeug.server.shutdown'] function which is "
"automatically removed because Django debug pages tries to call the "
"function and unintentionally shuts down the Werkzeug server.",
)
parser.add_argument(
"--nopin",
dest="nopin",
action="store_true",
default=False,
help="Disable the PIN in werkzeug. USE IT WISELY!",
)
if USE_STATICFILES:
parser.add_argument(
"--nostatic",
action="store_false",
dest="use_static_handler",
default=True,
help="Tells Django to NOT automatically serve static files.",
)
parser.add_argument(
"--insecure",
action="store_true",
dest="insecure_serving",
default=False,
help="Allows serving static files even if DEBUG is False.",
)
@signalcommand
def handle(self, *args, **options):
addrport = options["addrport"]
startup_messages = options["startup_messages"]
if startup_messages == "reload":
self.show_startup_messages = os.environ.get("RUNSERVER_PLUS_SHOW_MESSAGES")
elif startup_messages == "once":
self.show_startup_messages = not os.environ.get(
"RUNSERVER_PLUS_SHOW_MESSAGES"
)
elif startup_messages == "never":
self.show_startup_messages = False
else:
self.show_startup_messages = True
os.environ["RUNSERVER_PLUS_SHOW_MESSAGES"] = "1"
setup_logger(
logger, self.stderr, filename=options["output_file"]
) # , fmt="[%(name)s] %(message)s")
logredirect = RedirectHandler(__name__)
# Redirect werkzeug log items
werklogger = logging.getLogger("werkzeug")
werklogger.setLevel(logging.INFO)
werklogger.addHandler(logredirect)
werklogger.propagate = False
pdb_option = options["pdb"]
ipdb_option = options["ipdb"]
pm = options["pm"]
try:
from django_pdb.middleware import PdbMiddleware
except ImportError:
if pdb_option or ipdb_option or pm:
raise CommandError(
"django-pdb is required for --pdb, --ipdb and --pm options. "
"Please visit https://pypi.python.org/pypi/django-pdb or install "
"via pip. (pip install django-pdb)"
)
pm = False
else:
# Add pdb middleware if --pdb is specified or if in DEBUG mode
if pdb_option or ipdb_option or settings.DEBUG:
middleware = "django_pdb.middleware.PdbMiddleware"
settings_middleware = (
getattr(settings, "MIDDLEWARE", None) or settings.MIDDLEWARE_CLASSES
)
if middleware not in settings_middleware:
if isinstance(settings_middleware, tuple):
settings_middleware += (middleware,)
else:
settings_middleware += [middleware]
# If --pdb is specified then always break at the start of views.
# Otherwise break only if a 'pdb' query parameter is set in the url
if pdb_option:
PdbMiddleware.always_break = "pdb"
elif ipdb_option:
PdbMiddleware.always_break = "ipdb"
def postmortem(request, exc_type, exc_value, tb):
if has_ipdb():
import ipdb
p = ipdb
else:
import pdb
p = pdb
print(
"Exception occured: %s, %s" % (exc_type, exc_value), file=sys.stderr
)
p.post_mortem(tb)
# usurp django's handler
django_views_debug.technical_500_response = (
postmortem if pm else null_technical_500_response
)
self.use_ipv6 = options["use_ipv6"]
if self.use_ipv6 and not socket.has_ipv6:
raise CommandError("Your Python does not support IPv6.")
self._raw_ipv6 = False
if not addrport:
try:
addrport = settings.RUNSERVERPLUS_SERVER_ADDRESS_PORT
except AttributeError:
pass
if not addrport:
self.addr = ""
self.port = DEFAULT_PORT
else:
m = re.match(naiveip_re, addrport)
if m is None:
raise CommandError(
'"%s" is not a valid port number or address:port pair.' % addrport
)
self.addr, _ipv4, _ipv6, _fqdn, self.port = m.groups()
if not self.port.isdigit():
raise CommandError("%r is not a valid port number." % self.port)
if self.addr:
if _ipv6:
self.addr = self.addr[1:-1]
self.use_ipv6 = True
self._raw_ipv6 = True
elif self.use_ipv6 and not _fqdn:
raise CommandError('"%s" is not a valid IPv6 address.' % self.addr)
if not self.addr:
self.addr = "::1" if self.use_ipv6 else "127.0.0.1"
self._raw_ipv6 = True
truncate = None if options["truncate_sql"] == 0 else options["truncate_sql"]
with monkey_patch_cursordebugwrapper(
print_sql=options["print_sql"],
print_sql_location=options["print_sql_location"],
truncate=truncate,
logger=logger.info,
confprefix="RUNSERVER_PLUS",
):
self.inner_run(options)
def get_handler(self, *args, **options):
"""Return the default WSGI handler for the runner."""
return get_internal_wsgi_application()
def get_error_handler(self, exc, **options):
def application(env, start_response):
if isinstance(exc, SystemCheckError):
error_message = ansi_escape.sub("", str(exc))
raise SystemCheckError(error_message)
raise exc
return application
def inner_run(self, options):
if not HAS_WERKZEUG:
raise CommandError(
"Werkzeug is required to use runserver_plus. "
"Please visit https://werkzeug.palletsprojects.com/ or install via pip."
" (pip install Werkzeug)"
)
# Set colored output
if settings.DEBUG:
try:
set_werkzeug_log_color()
except (
Exception
): # We are dealing with some internals, anything could go wrong
if self.show_startup_messages:
print(
"Wrapping internal werkzeug logger "
"for color highlighting has failed!"
)
class WSGIRequestHandler(_WSGIRequestHandler):
def make_environ(self):
environ = super().make_environ()
if (
not options["keep_meta_shutdown_func"]
and "werkzeug.server.shutdown" in environ
):
del environ["werkzeug.server.shutdown"]
remote_user = os.getenv("REMOTE_USER")
if remote_user is not None:
environ["REMOTE_USER"] = remote_user
return environ
threaded = options["threaded"]
use_reloader = options["use_reloader"]
open_browser = options["open_browser"]
quit_command = "CONTROL-C" if sys.platform != "win32" else "CTRL-BREAK"
reloader_interval = options["reloader_interval"]
reloader_type = options["reloader_type"]
self.extra_files = set(options["extra_files"])
exclude_patterns = set(options["exclude_patterns"])
self.nopin = options["nopin"]
if self.show_startup_messages:
print("Performing system checks...\n")
try:
check_errors(self.check)(display_num_errors=self.show_startup_messages)
check_errors(self.check_migrations)()
handler = check_errors(self.get_handler)(**options)
except Exception as exc:
self.stderr.write("Error occurred during checks: %r" % exc, ending="\n\n")
handler = self.get_error_handler(exc, **options)
if USE_STATICFILES:
use_static_handler = options["use_static_handler"]
insecure_serving = options["insecure_serving"]
if use_static_handler and (settings.DEBUG or insecure_serving):
handler = StaticFilesHandler(handler)
if options["cert_path"] or options["key_file_path"]:
if not HAS_OPENSSL:
raise CommandError(
"Python OpenSSL Library is "
"required to use runserver_plus with ssl support. "
"Install via pip (pip install pyOpenSSL)."
)
certfile, keyfile = self.determine_ssl_files_paths(options)
dir_path, root = os.path.split(certfile)
root, _ = os.path.splitext(root)
try:
if os.path.exists(certfile) and os.path.exists(keyfile):
ssl_context = (certfile, keyfile)
else: # Create cert, key files ourselves.
ssl_context = make_ssl_devcert(
os.path.join(dir_path, root), host="localhost"
)
except ImportError:
if self.show_startup_messages:
print(
"Werkzeug version is less than 0.9, trying adhoc certificate."
)
ssl_context = "adhoc"
else:
ssl_context = None
bind_url = "%s://%s:%s/" % (
"https" if ssl_context else "http",
self.addr if not self._raw_ipv6 else "[%s]" % self.addr,
self.port,
)
if self.show_startup_messages:
print(
"\nDjango version %s, using settings %r"
% (django.get_version(), settings.SETTINGS_MODULE)
)
print("Development server is running at %s" % (bind_url,))
print("Using the Werkzeug debugger (https://werkzeug.palletsprojects.com/)")
print("Quit the server with %s." % quit_command)
if open_browser:
webbrowser.open(bind_url)
if use_reloader and settings.USE_I18N:
self.extra_files |= set(
filter(lambda filename: str(filename).endswith(".mo"), gen_filenames())
)
if getattr(settings, "RUNSERVER_PLUS_EXTRA_FILES", []):
self.extra_files |= set(settings.RUNSERVER_PLUS_EXTRA_FILES)
exclude_patterns |= set(
getattr(settings, "RUNSERVER_PLUS_EXCLUDE_PATTERNS", [])
)
# Werkzeug needs to be clued in its the main instance if running
# without reloader or else it won't show key.
# https://git.io/vVIgo
if not use_reloader:
os.environ["WERKZEUG_RUN_MAIN"] = "true"
# Don't run a second instance of the debugger / reloader
# See also: https://github.com/django-extensions/django-extensions/issues/832
if os.environ.get("WERKZEUG_RUN_MAIN") != "true":
if self.nopin:
os.environ["WERKZEUG_DEBUG_PIN"] = "off"
handler = DebuggedApplication(handler, True)
runserver_plus_started.send(sender=self)
run_simple(
self.addr,
int(self.port),
handler,
use_reloader=use_reloader,
use_debugger=True,
extra_files=self.extra_files,
exclude_patterns=exclude_patterns,
reloader_interval=reloader_interval,
reloader_type=reloader_type,
threaded=threaded,
request_handler=WSGIRequestHandler,
ssl_context=ssl_context,
)
@classmethod
def determine_ssl_files_paths(cls, options):
key_file_path = os.path.expanduser(options.get("key_file_path") or "")
cert_path = os.path.expanduser(options.get("cert_path") or "")
cert_file = cls._determine_path_for_file(
cert_path, key_file_path, cls.DEFAULT_CRT_EXTENSION
)
key_file = cls._determine_path_for_file(
key_file_path, cert_path, cls.DEFAULT_KEY_EXTENSION
)
return cert_file, key_file
@classmethod
def _determine_path_for_file(
cls, current_file_path, other_file_path, expected_extension
):
directory = cls._get_directory_basing_on_file_paths(
current_file_path, other_file_path
)
file_name = cls._get_file_name(current_file_path) or cls._get_file_name(
other_file_path
)
extension = cls._get_extension(current_file_path) or expected_extension
return os.path.join(directory, file_name + extension)
@classmethod
def _get_directory_basing_on_file_paths(cls, current_file_path, other_file_path):
return (
cls._get_directory(current_file_path)
or cls._get_directory(other_file_path)
or os.getcwd()
)
@classmethod
def _get_directory(cls, file_path):
return os.path.split(file_path)[0]
@classmethod
def _get_file_name(cls, file_path):
return os.path.splitext(os.path.split(file_path)[1])[0]
@classmethod
def _get_extension(cls, file_path):
return os.path.splitext(file_path)[1]
def set_werkzeug_log_color():
"""Try to set color to the werkzeug log."""
_style = color_style()
_orig_log = _WSGIRequestHandler.log
def werk_log(self, type, message, *args):
try:
msg = "%s - - [%s] %s" % (
self.address_string(),
self.log_date_time_string(),
message % args,
)
http_code = str(args[1])
except Exception:
return _orig_log(type, message, *args)
# Utilize terminal colors, if available
if http_code[0] == "2":
# Put 2XX first, since it should be the common case
msg = _style.HTTP_SUCCESS(msg)
elif http_code[0] == "1":
msg = _style.HTTP_INFO(msg)
elif http_code == "304":
msg = _style.HTTP_NOT_MODIFIED(msg)
elif http_code[0] == "3":
msg = _style.HTTP_REDIRECT(msg)
elif http_code == "404":
msg = _style.HTTP_NOT_FOUND(msg)
elif http_code[0] == "4":
msg = _style.HTTP_BAD_REQUEST(msg)
else:
# Any 5XX, or any other response
msg = _style.HTTP_SERVER_ERROR(msg)
_log(type, msg)
_WSGIRequestHandler.log = werk_log

View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
import socket
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from django.apps import apps
from django_extensions.management.utils import signalcommand
class Command(BaseCommand):
help = "Set parameters of the default django.contrib.sites Site"
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--name", dest="site_name", default=None, help="Use this as site name."
)
parser.add_argument(
"--domain",
dest="site_domain",
default=None,
help="Use this as site domain.",
)
parser.add_argument(
"--system-fqdn",
dest="set_as_system_fqdn",
default=False,
action="store_true",
help="Use the systems FQDN (Fully Qualified Domain Name) as name "
"and domain. Can be used in combination with --name",
)
@signalcommand
def handle(self, *args, **options):
if not apps.is_installed("django.contrib.sites"):
raise CommandError("The sites framework is not installed.")
from django.contrib.sites.models import Site
try:
site = Site.objects.get(pk=settings.SITE_ID)
except Site.DoesNotExist:
raise CommandError(
"Default site with pk=%s does not exist" % settings.SITE_ID
)
else:
name = options["site_name"]
domain = options["site_domain"]
set_as_system_fqdn = options["set_as_system_fqdn"]
if all([domain, set_as_system_fqdn]):
raise CommandError(
"The set_as_system_fqdn cannot be used with domain option."
) # noqa
if set_as_system_fqdn:
domain = socket.getfqdn()
if not domain:
raise CommandError("Cannot find systems FQDN")
if name is None:
name = domain
update_kwargs = {}
if name and name != site.name:
update_kwargs["name"] = name
if domain and domain != site.domain:
update_kwargs["domain"] = domain
if update_kwargs:
Site.objects.filter(pk=settings.SITE_ID).update(**update_kwargs)
site = Site.objects.get(pk=settings.SITE_ID)
print(
"Updated default site. You might need to restart django as sites"
" are cached aggressively."
)
else:
print("Nothing to update (need --name, --domain and/or --system-fqdn)")
print("Default Site:")
print("\tid = %s" % site.id)
print("\tname = %s" % site.name)
print("\tdomain = %s" % site.domain)

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""
set_fake_emails.py
Give all users a new email account. Useful for testing in a
development environment. As such, this command is only available when
setting.DEBUG is True.
"""
from typing import List
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
DEFAULT_FAKE_EMAIL = "%(username)s@example.com"
class Command(BaseCommand):
help = (
"DEBUG only: give all users a new email based on their account data "
'("%s" by default). '
"Possible parameters are: username, first_name, last_name"
) % (DEFAULT_FAKE_EMAIL,)
requires_system_checks: List[str] = []
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--email",
dest="default_email",
default=DEFAULT_FAKE_EMAIL,
help="Use this as the new email format.",
)
parser.add_argument(
"-a",
"--no-admin",
action="store_true",
dest="no_admin",
default=False,
help="Do not change administrator accounts",
)
parser.add_argument(
"-s",
"--no-staff",
action="store_true",
dest="no_staff",
default=False,
help="Do not change staff accounts",
)
parser.add_argument(
"--include",
dest="include_regexp",
default=None,
help="Include usernames matching this regexp.",
)
parser.add_argument(
"--exclude",
dest="exclude_regexp",
default=None,
help="Exclude usernames matching this regexp.",
)
parser.add_argument(
"--include-groups",
dest="include_groups",
default=None,
help=(
"Include users matching this group. "
"(use comma separation for multiple groups)"
),
)
parser.add_argument(
"--exclude-groups",
dest="exclude_groups",
default=None,
help=(
"Exclude users matching this group. "
"(use comma separation for multiple groups)"
),
)
@signalcommand
def handle(self, *args, **options):
if not settings.DEBUG:
raise CommandError("Only available in debug mode")
from django.contrib.auth.models import Group
email = options["default_email"]
include_regexp = options["include_regexp"]
exclude_regexp = options["exclude_regexp"]
include_groups = options["include_groups"]
exclude_groups = options["exclude_groups"]
no_admin = options["no_admin"]
no_staff = options["no_staff"]
User = get_user_model()
users = User.objects.all()
if no_admin:
users = users.exclude(is_superuser=True)
if no_staff:
users = users.exclude(is_staff=True)
if exclude_groups:
groups = Group.objects.filter(name__in=exclude_groups.split(","))
if groups:
users = users.exclude(groups__in=groups)
else:
raise CommandError("No groups matches filter: %s" % exclude_groups)
if include_groups:
groups = Group.objects.filter(name__in=include_groups.split(","))
if groups:
users = users.filter(groups__in=groups)
else:
raise CommandError("No groups matches filter: %s" % include_groups)
if exclude_regexp:
users = users.exclude(username__regex=exclude_regexp)
if include_regexp:
users = users.filter(username__regex=include_regexp)
for user in users:
user.email = email % {
"username": user.username,
"first_name": user.first_name,
"last_name": user.last_name,
}
user.save()
print("Changed %d emails" % users.count())

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""
set_fake_passwords.py
Reset all user passwords to a common value. Useful for testing in a
development environment. As such, this command is only available when
setting.DEBUG is True.
"""
from typing import List
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError
from django_extensions.management.utils import signalcommand
DEFAULT_FAKE_PASSWORD = "password"
class Command(BaseCommand):
help = 'DEBUG only: sets all user passwords to a common value ("%s" by default)' % (
DEFAULT_FAKE_PASSWORD,
)
requires_system_checks: List[str] = []
def add_arguments(self, parser):
super().add_arguments(parser)
parser.add_argument(
"--prompt",
dest="prompt_passwd",
default=False,
action="store_true",
help="Prompts for the new password to apply to all users",
)
parser.add_argument(
"--password",
dest="default_passwd",
default=DEFAULT_FAKE_PASSWORD,
help="Use this as default password.",
)
@signalcommand
def handle(self, *args, **options):
if not settings.DEBUG:
raise CommandError("Only available in debug mode")
if options["prompt_passwd"]:
from getpass import getpass
passwd = getpass("Password: ")
if not passwd:
raise CommandError("You must enter a valid password")
else:
passwd = options["default_passwd"]
User = get_user_model()
user = User()
user.set_password(passwd)
count = User.objects.all().update(password=user.password)
print("Reset %d passwords" % count)

Some files were not shown because too many files have changed in this diff Show More