109 lines
4.4 KiB
Python
109 lines
4.4 KiB
Python
"""
|
|
Management command to train predictive models
|
|
"""
|
|
from django.core.management.base import BaseCommand, CommandError
|
|
from analytics_predictive_insights.models import PredictiveModel
|
|
from analytics_predictive_insights.ml.predictive_models import PredictiveModelService
|
|
|
|
|
|
class Command(BaseCommand):
    """Train predictive models.

    Selection rules, in priority order:

    * ``--model-id ID``: train only the model with that ID (``--force`` is
      not consulted in this case).
    * ``--force``: train every model whose ``model_type`` is one of
      ``_FORCE_MODEL_TYPES``, with no filter on status.
    * otherwise: train the models whose ``status`` is ``'TRAINING'``.

    Per-model failures are reported and counted but do not abort the run.
    """

    help = 'Train predictive models that are in training status'

    # Model types picked up by ``--force`` when no explicit ID is given.
    _FORCE_MODEL_TYPES = (
        'INCIDENT_PREDICTION',
        'SEVERITY_PREDICTION',
        'RESOLUTION_TIME_PREDICTION',
        'COST_PREDICTION',
    )

    # (display label, key in the ``metrics`` mapping), in output order.
    _METRIC_LABELS = (
        ('Accuracy', 'accuracy'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1 Score', 'f1_score'),
        ('R2 Score', 'r2_score'),
    )

    def add_arguments(self, parser):
        """Register the ``--model-id`` and ``--force`` options."""
        parser.add_argument(
            '--model-id',
            type=str,
            help='Train a specific model ID only'
        )
        parser.add_argument(
            '--force',
            action='store_true',
            help='Force retraining of active models'
        )

    def handle(self, *args, **options):
        """Train the selected models and print a summary.

        Raises:
            CommandError: if ``--model-id`` matches no model (raised as-is,
                without an extra "Error executing command" prefix), or if
                any unexpected error occurs outside the per-model loop.
        """
        model_id = options.get('model_id')
        force = options.get('force', False)

        try:
            model_service = PredictiveModelService()
            models = self._select_models(model_id, force)

            self.stdout.write(f'Training {len(models)} models...')

            total_trained = 0
            total_failed = 0

            for model in models:
                # Best-effort: one model failing must not stop the others.
                try:
                    succeeded = self._train_one(model_service, model)
                except Exception as e:
                    succeeded = False
                    self.stdout.write(
                        self.style.ERROR(f'✗ Error training {model.name}: {str(e)}')
                    )

                # Count each model exactly once, even when an error occurs
                # while reporting an otherwise successful training.
                if succeeded:
                    total_trained += 1
                else:
                    total_failed += 1

            self.stdout.write('\nTraining Summary:')
            self.stdout.write(f' Successfully trained: {total_trained}')
            self.stdout.write(f' Failed: {total_failed}')

            if total_trained > 0:
                self.stdout.write(
                    self.style.SUCCESS('✓ Training completed successfully')
                )
            else:
                self.stdout.write(
                    self.style.WARNING('⚠ No models were successfully trained')
                )

        except CommandError:
            # Already a user-facing error; do not wrap it a second time.
            raise
        except Exception as e:
            raise CommandError(f'Error executing command: {str(e)}') from e

    def _select_models(self, model_id, force):
        """Return the models to train as a list, evaluated once.

        Materialising the queryset keeps the announced count and the set of
        models actually iterated consistent with each other.

        Raises:
            CommandError: if ``model_id`` is given and matches no model.
        """
        if model_id:
            models = list(PredictiveModel.objects.filter(id=model_id))
            if not models:
                raise CommandError(f'No model found with ID: {model_id}')
            return models

        if force:
            queryset = PredictiveModel.objects.filter(
                model_type__in=self._FORCE_MODEL_TYPES
            )
        else:
            queryset = PredictiveModel.objects.filter(status='TRAINING')
        return list(queryset)

    def _train_one(self, model_service, model):
        """Train a single model, print its outcome, and return True on success.

        Exceptions from the service or from reporting propagate to the
        caller, which records the model as failed.
        """
        self.stdout.write(f'Training model: {model.name}...')

        result = model_service.train_model(str(model.id))

        if not result['success']:
            self.stdout.write(
                self.style.ERROR(f'✗ Failed to train {model.name}: {result.get("error", "Unknown error")}')
            )
            return False

        self.stdout.write(
            self.style.SUCCESS(f'✓ Successfully trained {model.name}')
        )

        # Display metrics
        if 'metrics' in result:
            metrics = result['metrics']
            for label, key in self._METRIC_LABELS:
                self.stdout.write(f' {label}: {metrics.get(key, "N/A")}')

        # NOTE(review): assumed to be printed whether or not metrics are
        # present (they read from ``result``, not ``metrics``) — confirm.
        self.stdout.write(f' Training samples: {result.get("training_samples", "N/A")}')
        self.stdout.write(f' Training duration: {result.get("training_duration", "N/A")} seconds')
        return True
|