Source code for neuroscope.diagnostics.posttraining

"""
Post-Training Evaluation for NeuroScope MLP Framework
Focused evaluation tools for assessing neural network quality after training completes.
"""

import time
from typing import Any, Dict, List, Optional

import numpy as np


class PostTrainingEvaluator:
    """
    Comprehensive post-training evaluation system for neural networks.

    Provides thorough analysis of trained model performance including
    robustness testing, performance metrics evaluation, and diagnostic
    assessments. Designed to validate model quality and identify potential
    deployment issues after training completion.

    Args:
        model: Trained and compiled MLP model instance with initialized weights.

    Attributes:
        model: Reference to the trained neural network model.
        results (dict): Cached evaluation results from various assessments.

    Example:
        >>> from neuroscope.diagnostics import PostTrainingEvaluator
        >>> model = MLP([784, 128, 10])
        >>> model.compile(lr=1e-3)
        >>> history = model.fit(X_train, y_train, epochs=100)
        >>> evaluator = PostTrainingEvaluator(model)
        >>> evaluator.evaluate(X_test, y_test)
        >>> # Access detailed results
        >>> robustness = evaluator.evaluate_robustness(X_test, y_test)
        >>> performance = evaluator.evaluate_performance(X_test, y_test)
    """

    def __init__(self, model):
        """Initialize evaluator with a trained model."""
        if not hasattr(model, "weights") or not hasattr(model, "biases"):
            raise ValueError("Model must be weight initialized.")
        if not getattr(model, "compiled", False):
            raise ValueError("Model must be compiled.")
        self.model = model
        self.results = {}

    def evaluate_robustness(
        self, X: np.ndarray, y: np.ndarray, noise_levels: Optional[List[float]] = None
    ) -> Dict[str, Any]:
        """Evaluate model robustness against Gaussian noise."""
        if noise_levels is None:
            noise_levels = [0.01, 0.05, 0.1, 0.2]
        try:
            baseline_loss, baseline_accuracy = self.model.evaluate(X, y)
            baseline_predictions = self.model.predict(X)
        except Exception as e:
            return {
                "status": "ERROR",
                "error": str(e),
                "note": "Failed to compute baseline performance",
            }
        robustness_scores = []
        for noise_level in noise_levels:
            try:
                X_noisy = X + np.random.normal(0, noise_level, X.shape)
                noisy_loss, noisy_accuracy = self.model.evaluate(X_noisy, y)
                noisy_predictions = self.model.predict(X_noisy)
                accuracy_drop = baseline_accuracy - noisy_accuracy
                if baseline_predictions.shape[1] > 1:
                    # Multiclass: fraction of samples whose predicted class is unchanged
                    consistency = np.mean(
                        np.argmax(baseline_predictions, axis=1)
                        == np.argmax(noisy_predictions, axis=1)
                    )
                else:
                    # Single output: correlation between clean and noisy predictions
                    consistency = max(
                        0,
                        np.corrcoef(
                            baseline_predictions.flatten(), noisy_predictions.flatten()
                        )[0, 1],
                    )
                accuracy_robustness = max(
                    0, 1 - (accuracy_drop / (baseline_accuracy + 1e-8))
                )
                robustness_scores.append((accuracy_robustness + consistency) / 2)
            except Exception:
                # Skip this noise level if evaluation fails
                continue
        overall_robustness = np.mean(robustness_scores) if robustness_scores else 0.0
        if overall_robustness >= 0.8:
            status, note = "EXCELLENT", "Highly robust to noise"
        elif overall_robustness >= 0.6:
            status, note = "PASS", "Good noise robustness"
        elif overall_robustness >= 0.4:
            status, note = "WARN", "Moderate robustness"
        else:
            status, note = "FAIL", "Poor noise robustness"
        return {
            "baseline_accuracy": baseline_accuracy,
            "baseline_loss": baseline_loss,
            "overall_robustness": overall_robustness,
            "status": status,
            "note": note,
        }

    def evaluate_performance(self, X: np.ndarray, y: np.ndarray) -> Dict[str, Any]:
        """Evaluate model performance metrics."""
        try:
            # Warm up the model with a small prediction to stabilize timing
            if X.shape[0] > 1:
                _ = self.model.predict(X[:1])
            # Take multiple timing measurements for more stable results
            times = []
            for _ in range(3):
                start_time = time.time()
                predictions = self.model.predict(X)
                times.append(time.time() - start_time)
            # Use median time for more robust measurement
            prediction_time = sorted(times)[1]  # median of 3 measurements
            loss, primary_accuracy = self.model.evaluate(X, y)
            # Ensure minimum time to avoid division by zero and unrealistic values
            min_time = 1e-6  # 1 microsecond minimum
            prediction_time = max(prediction_time, min_time)
            samples_per_second = X.shape[0] / prediction_time
            total_params = sum(
                w.size + b.size for w, b in zip(self.model.weights, self.model.biases)
            )
            all_metrics = self._evaluate_all_metrics(X, y, predictions)
            # Status assessment
            if primary_accuracy >= 0.9 and samples_per_second >= 1000:
                status, note = "EXCELLENT", "High accuracy and fast inference"
            elif primary_accuracy >= 0.8:
                status, note = "PASS", "Good overall performance"
            elif primary_accuracy >= 0.6:
                status, note = "WARN", "Moderate performance"
            else:
                status, note = "FAIL", "Poor performance"
            return {
                "accuracy": primary_accuracy,
                "loss": loss,
                "samples_per_second": samples_per_second,
                "total_params": total_params,
                "all_metrics": all_metrics,
                "status": status,
                "note": note,
            }
        except Exception as e:
            return {
                "status": "ERROR",
                "error": str(e),
                "note": "Failed to evaluate performance",
            }

    def _evaluate_all_metrics(
        self, X: np.ndarray, y: np.ndarray, predictions: np.ndarray
    ) -> Dict[str, float]:
        """Evaluate all available metrics from the metrics module."""
        try:
            from neuroscope.mlp.metrics import Metrics
        except ImportError:
            try:
                from ..mlp.metrics import Metrics
            except ImportError as e:
                return {"error": f"Failed to import metrics: {str(e)}"}
        # Determine task type
        is_multiclass = predictions.shape[1] > 1
        is_binary = predictions.shape[1] == 1 and len(np.unique(y)) <= 10
        is_regression = not (is_multiclass or is_binary)
        metrics_results = {}
        try:
            # Regression metrics (for regression and binary classification only)
            if is_regression or is_binary:
                metrics_results.update(
                    {
                        "mse": Metrics.mse(y, predictions),
                        "rmse": Metrics.rmse(y, predictions),
                        "mae": Metrics.mae(y, predictions),
                    }
                )
            # Classification metrics
            if is_binary:
                metrics_results.update(
                    {
                        "accuracy_binary": Metrics.accuracy_binary(y, predictions),
                        "precision": Metrics.precision(y, predictions),
                        "recall": Metrics.recall(y, predictions),
                        "f1_score": Metrics.f1_score(y, predictions),
                    }
                )
            elif is_multiclass:
                metrics_results.update(
                    {
                        "accuracy_multiclass": Metrics.accuracy_multiclass(
                            y, predictions
                        ),
                        "precision": Metrics.precision(y, predictions, average="macro"),
                        "recall": Metrics.recall(y, predictions, average="macro"),
                        "f1_score": Metrics.f1_score(y, predictions, average="macro"),
                    }
                )
            # Regression-specific metrics
            if is_regression:
                metrics_results["r2_score"] = Metrics.r2_score(y, predictions)
        except Exception as e:
            metrics_results["error"] = str(e)
        return metrics_results

    def evaluate_stability(
        self, X: np.ndarray, y: np.ndarray, n_samples: int = 100
    ) -> Dict[str, Any]:
        """Evaluate prediction stability across similar inputs (nearest-neighbor approach)."""
        try:
            if X.shape[0] < 2:
                return {
                    "status": "ERROR",
                    "error": "Insufficient data",
                    "note": "Need at least 2 samples",
                }
            n_test = min(n_samples, X.shape[0])
            test_indices = np.random.choice(X.shape[0], n_test, replace=False)
            stability_scores = []
            for idx in test_indices:
                x_ref = X[idx : idx + 1]
                pred_ref = self.model.predict(x_ref)
                # Find the nearest neighbor of the reference sample (excluding itself)
                distances = np.linalg.norm(X - x_ref, axis=1)
                distances[idx] = np.inf
                neighbor_idx = np.argmin(distances)
                if distances[neighbor_idx] == np.inf:
                    continue
                pred_neighbor = self.model.predict(X[neighbor_idx : neighbor_idx + 1])
                if pred_ref.shape[1] > 1:
                    # Classification: stable if the predicted class does not change
                    stability = (
                        1.0 if np.argmax(pred_ref) == np.argmax(pred_neighbor) else 0.0
                    )
                else:
                    # Regression: decay stability with prediction distance, scaled by std of y
                    pred_distance = np.abs(
                        pred_ref.flatten()[0] - pred_neighbor.flatten()[0]
                    )
                    output_scale = np.std(y) + 1e-8
                    stability = np.exp(-pred_distance / output_scale)
                stability_scores.append(stability)
            overall_stability = np.mean(stability_scores) if stability_scores else 0.0
            if overall_stability >= 0.8:
                status, note = "EXCELLENT", "Highly stable predictions"
            elif overall_stability >= 0.6:
                status, note = "PASS", "Good prediction stability"
            elif overall_stability >= 0.4:
                status, note = "WARN", "Moderate stability issues"
            else:
                status, note = "FAIL", "Poor prediction stability"
            return {
                "overall_stability": overall_stability,
                "status": status,
                "note": note,
            }
        except Exception as e:
            return {
                "status": "ERROR",
                "error": str(e),
                "note": "Failed to evaluate stability",
            }

    def evaluate(
        self,
        X_test: np.ndarray,
        y_test: np.ndarray,
        X_train: Optional[np.ndarray] = None,
        y_train: Optional[np.ndarray] = None,
    ):
        """Run comprehensive model evaluation and generate summary report."""
        print("=" * 80)
        print(" NEUROSCOPE POST-TRAINING EVALUATION")
        print("=" * 80)
        evaluations = [
            ("Robustness", lambda: self.evaluate_robustness(X_test, y_test)),
            ("Performance", lambda: self.evaluate_performance(X_test, y_test)),
            ("Stability", lambda: self.evaluate_stability(X_test, y_test)),
        ]
        if X_train is not None and y_train is not None:
            evaluations.append(
                (
                    "Generalization",
                    lambda: self.evaluate_generalization(
                        X_train, y_train, X_test, y_test
                    ),
                )
            )
        all_results = {}
        print(f"{'EVALUATION':<15} {'STATUS':<12} {'SCORE':<12} {'NOTE':<45}")
        print("-" * 80)
        for eval_name, eval_func in evaluations:
            try:
                result = eval_func()
                all_results[eval_name] = result
                status = result.get("status", "UNKNOWN")
                note = result.get("note", "-")
                score_keys = [
                    "overall_robustness",
                    "accuracy",
                    "generalization_score",
                    "overall_stability",
                ]
                score = next(
                    (f"{result[key]:.3f}" for key in score_keys if key in result), "N/A"
                )
                print(f"{eval_name:<15} {status:<12} {score:<12} {note:<45}")
            except Exception as e:
                all_results[eval_name] = {
                    "status": "ERROR",
                    "error": str(e),
                    "note": "Evaluation failed",
                }
                print(
                    f"{eval_name:<15} {'ERROR':<12} {'N/A':<12} {'Evaluation failed':<45}"
                )
        print("-" * 80)
        status_counts = {}
        for result in all_results.values():
            status = result.get("status", "UNKNOWN")
            status_counts[status] = status_counts.get(status, 0) + 1
        if status_counts.get("ERROR", 0) > 0:
            overall_status = "EVALUATION ERRORS"
        elif status_counts.get("FAIL", 0) > 0:
            overall_status = "ISSUES DETECTED"
        elif status_counts.get("WARN", 0) > 0:
            overall_status = "WARNINGS PRESENT"
        else:
            overall_status = "EVALUATION COMPLETE"
        pass_count = status_counts.get("PASS", 0) + status_counts.get("EXCELLENT", 0)
        print(f"OVERALL STATUS: {overall_status}")
        print(f"EVALUATIONS PASSED: {pass_count}/{len(evaluations)}")
        if (
            "Performance" in all_results
            and all_results["Performance"].get("status") != "ERROR"
        ):
            self._display_metrics_evaluation(all_results["Performance"])
        print("=" * 80)
        self.results = all_results

    def _display_metrics_evaluation(self, performance_result: Dict[str, Any]):
        """Display metrics evaluation in a structured table format."""
        all_metrics = performance_result.get("all_metrics", {})
        if not all_metrics or "error" in all_metrics:
            return
        is_classification = any(
            metric in all_metrics
            for metric in [
                "accuracy_binary",
                "accuracy_multiclass",
                "precision",
                "recall",
                "f1_score",
            ]
        )
        task_type = "CLASSIFICATION" if is_classification else "REGRESSION"
        print("=" * 80)
        print(f" {task_type} METRICS")
        print("=" * 80)
        if is_classification:
            metric_order = [
                ("accuracy_binary", "Accuracy"),
                ("accuracy_multiclass", "Accuracy"),
                ("precision", "Precision"),
                ("recall", "Recall"),
                ("f1_score", "F1-Score"),
            ]

            def get_status(key, value):
                if value >= 0.95:
                    return "EXCELLENT", "Outstanding performance"
                if value >= 0.85:
                    return "PASS", "Good performance"
                if value >= 0.70:
                    return "WARN", "Moderate performance"
                return "FAIL", "Poor performance"

        else:
            metric_order = [
                ("r2_score", "R² Score"),
                ("mae", "Mean Absolute Error"),
                ("mse", "Mean Squared Error"),
                ("rmse", "Root Mean Squared Error"),
            ]

            def get_status(key, value):
                if key == "r2_score":
                    if value >= 0.9:
                        return "EXCELLENT", "Outstanding fit"
                    if value >= 0.7:
                        return "PASS", "Good fit"
                    if value >= 0.5:
                        return "WARN", "Moderate fit"
                    return "FAIL", "Poor fit"
                if value <= 0.1:
                    return "EXCELLENT", "Very low error"
                if value <= 0.3:
                    return "PASS", "Low error"
                if value <= 0.5:
                    return "WARN", "Moderate error"
                return "FAIL", "High error"

        metrics_to_display = []
        for metric_key, display_name in metric_order:
            if metric_key in all_metrics and isinstance(
                all_metrics[metric_key], (int, float)
            ):
                value = all_metrics[metric_key]
                status, note = get_status(metric_key, value)
                metrics_to_display.append((display_name, status, value, note))
        print(f"{'METRIC':<20} {'STATUS':<12} {'SCORE':<12} {'NOTE':<40}")
        print("-" * 80)
        for metric_name, status, value, note in metrics_to_display:
            score_str = f"{value:.4f}"
            print(f"{metric_name:<20} {status:<12} {score_str:<12} {note:<40}")
        print("-" * 80)
        status_counts = {}
        for _, status, _, _ in metrics_to_display:
            status_counts[status] = status_counts.get(status, 0) + 1
        metrics_pass_count = status_counts.get("PASS", 0) + status_counts.get(
            "EXCELLENT", 0
        )
        total_metrics = len(metrics_to_display)
        if status_counts.get("FAIL", 0) > 0:
            metrics_overall = "METRICS ISSUES DETECTED"
        elif status_counts.get("WARN", 0) > 0:
            metrics_overall = "SOME METRICS WARNINGS"
        else:
            metrics_overall = "METRICS EVALUATION COMPLETE"
        print(f"METRICS STATUS: {metrics_overall}")
        print(f"METRICS PASSED: {metrics_pass_count}/{total_metrics}")
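
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the evaluator implementation).
# It assumes the MLP constructor/compile/fit interface shown in the class
# docstring above and an `neuroscope.mlp` import path for MLP; both are
# assumptions to adjust for your installation. Synthetic data stands in for a
# real train/test split.
#
#     import numpy as np
#     from neuroscope.mlp import MLP                       # assumed import path
#     from neuroscope.diagnostics import PostTrainingEvaluator
#
#     rng = np.random.default_rng(0)
#     X_train, y_train = rng.normal(size=(512, 784)), rng.integers(0, 10, 512)
#     X_test, y_test = rng.normal(size=(128, 784)), rng.integers(0, 10, 128)
#
#     model = MLP([784, 128, 10])
#     model.compile(lr=1e-3)
#     model.fit(X_train, y_train, epochs=20)
#
#     evaluator = PostTrainingEvaluator(model)
#     evaluator.evaluate(X_test, y_test)                   # printed summary report
#     robustness = evaluator.evaluate_robustness(X_test, y_test)
#     stability = evaluator.evaluate_stability(X_test, y_test)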