Source code for neuroscope.viz.plots

"""
NeuroScope Visualization Module
High-quality plotting tools for neural network training analysis.
"""

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import PillowWriter


class Visualizer:
    """
    High-quality visualization tool for neural network training analysis.

    Provides comprehensive plotting capabilities for analyzing training
    dynamics, network behavior, and diagnostic information. Creates
    professional-grade figures suitable for research publications and
    presentations.

    Args:
        hist (dict): Complete training history from model.fit() containing:
            - history: Training/validation metrics per epoch
            - weights/biases: Final network parameters
            - activations/gradients: Sample network internals
            - *_stats_over_epochs: Statistical evolution during training

    Attributes:
        hist (dict): Complete training history data.
        history (dict): Training metrics (loss, accuracy) evolution.
        weights/biases: Final network parameters.
        activations/gradients: Representative network internals.

    Example:
        >>> from neuroscope.viz import Visualizer
        >>> history = model.fit(X_train, y_train, epochs=100)
        >>> viz = Visualizer(history)
        >>> viz.plot_learning_curves()
        >>> viz.plot_activation_hist()
        >>> viz.plot_gradient_norms()
    """

    def __init__(self, hist):
        """
        Initialize visualizer with comprehensive training history.

        Sets up visualization infrastructure and applies publication-quality
        styling to all plots. Automatically extracts relevant data components
        for different types of analysis.

        Args:
            hist (dict): Training history from model.fit() containing all
                training statistics, network states, and diagnostic information.
        """
        self.hist = hist
        self.method = self.hist.get("method")
        if self.method == "fit_fast":
            self.history = hist["history"]
            self.weights = hist.get("weights", None)
            self.biases = hist.get("biases", None)
            self.metric = hist.get("metric", "accuracy")
            self.metric_display_name = hist.get("metric_display_name", "Accuracy")
        else:
            self.history = hist["history"]
            self.weights = hist.get("weights", None)
            self.biases = hist.get("biases", None)
            self.activations = hist.get("activations", {})
            self.gradients = hist.get("gradients", {})
            self.weight_stats_over_epochs = hist.get("weight_stats_over_epochs", {})
            self.activation_stats_over_epochs = hist.get(
                "activation_stats_over_epochs", {}
            )
            self.gradient_stats_over_epochs = hist.get("gradient_stats_over_epochs", {})
            self.epoch_distributions = hist.get("epoch_distributions", {})
            self.gradient_norms = hist.get("gradient_norms_over_epochs", {})
            self.weight_update_ratios = hist.get("weight_update_ratios_over_epochs", {})
            self.metric = hist.get("metric", "accuracy")
            self.metric_display_name = hist.get("metric_display_name", "Accuracy")
        self.epochs_ = hist.get("epochs", None)
        self._setup_style()
        self._setup_colors()
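
    # Expected shape of `hist`, inferred from the .get() calls above (the
    # exact contents are produced by model.fit() / model.fit_fast(); keys
    # other than "history" are optional):
    #
    #   hist = {
    #       "method": "fit" or "fit_fast",
    #       "history": {"train_loss": [...], "val_loss": [...],
    #                   "train_acc": [...], "val_acc": [...]},
    #       "weights": [...], "biases": [...],
    #       "gradient_norms_over_epochs": {"layer_0": [...], ...},
    #       "weight_update_ratios_over_epochs": {"layer_0": [...], ...},
    #       "epoch_distributions": {"weights": {...}, "gradients": {...},
    #                               "activations": {...}},
    #       "metric": "accuracy",
    #       "metric_display_name": "Accuracy",
    #   }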

    def _setup_colors(self):
        """Define consistent color palette for all plots."""
        self.colors = {
            "train": "#1f77b4",
            "validation": "#ff7f0e",
            "layers": [
                "#1f77b4",
                "#ff7f0e",
                "#2ca02c",
                "#d62728",
                "#9467bd",
                "#8c564b",
                "#e377c2",
                "#7f7f7f",
                "#bcbd22",
                "#17becf",
            ],
            "reference": {
                "good": "#2ca02c",
                "warning": "#ff7f0e",
                "critical": "#d62728",
            },
        }
        # Line styles
        self.line_style = {"width": 1.2, "alpha": 0.8, "markersize": 3, "markevery": 2}

    def _setup_style(self):
        """Configure publication-quality matplotlib styling."""
        plt.style.use("default")
        plt.rcParams.update(
            {
                "font.family": "serif",
                "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
                "font.size": 10,
                "axes.linewidth": 0.8,
                "axes.spines.left": True,
                "axes.spines.bottom": True,
                "axes.spines.top": False,
                "axes.spines.right": False,
                "xtick.direction": "in",
                "ytick.direction": "in",
                "xtick.major.size": 3,
                "ytick.major.size": 3,
                "xtick.minor.size": 1.5,
                "ytick.minor.size": 1.5,
                "legend.frameon": False,
                "figure.facecolor": "white",
            }
        )

    def _configure_plot(self, title=None, xlabel=None, ylabel=None, figsize=(8, 6)):
        """Create and configure a single plot."""
        fig, ax = plt.subplots(figsize=figsize)
        if title:
            ax.set_title(title, fontsize=11, fontweight="normal")
        if xlabel:
            ax.set_xlabel(xlabel)
        if ylabel:
            ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3, linewidth=0.5)
        ax.tick_params(labelsize=9)
        return fig, ax

    def plot_learning_curves(
        self, figsize=(9, 4), ci=False, markers=True, save_path=None, metric="accuracy"
    ):
        """
        Plot training and validation learning curves for regular fit() results.

        Creates a two-panel figure showing loss and metric evolution during
        training. Automatically detects available data and applies consistent
        styling with optional confidence intervals.

        Note: For fit_fast() results, use plot_curves_fast() instead.

        Args:
            figsize (tuple[int, int], optional): Figure dimensions (width, height).
                Defaults to (9, 4).
            ci (bool, optional): Whether to add confidence intervals using moving
                window statistics. Only available for regular fit() results.
                Defaults to False.
            markers (bool, optional): Whether to show markers on line plots.
                Defaults to True.
            save_path (str, optional): Path to save the figure. Defaults to None.
            metric (str, optional): Name of the metric for y-axis label.
                Defaults to 'accuracy'.

        Example:
            >>> viz.plot_learning_curves(figsize=(10, 5), ci=True, save_path='curves.png')
        """
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        if self.epochs_:
            epochs = self.epochs_
        else:
            epochs = np.arange(len(self.history["train_loss"])) + 1
        marker_settings = {
            "train": (
                {"marker": "o", "markersize": 3, "markevery": 2}
                if markers
                else {"marker": None}
            ),
            "val": (
                {"marker": "s", "markersize": 3, "markevery": 2}
                if markers
                else {"marker": None}
            ),
        }
        ax1.plot(
            epochs,
            self.history["train_loss"],
            label="Training",
            linewidth=1.2,
            color="#1f77b4",
            **marker_settings["train"],
        )
        ax1.plot(
            epochs,
            self.history["val_loss"],
            label="Validation",
            linewidth=1.2,
            color="#ff7f0e",
            **marker_settings["val"],
        )
        if ci and len(epochs) > 3:
            self._add_confidence_intervals(
                ax1, epochs, self.history["train_loss"], "#1f77b4"
            )
            self._add_confidence_intervals(
                ax1, epochs, self.history["val_loss"], "#ff7f0e"
            )
        ax1.set_title("Loss", fontsize=11, fontweight="normal")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
        ax1.legend(fontsize=9, loc="upper right")
        ax1.grid(True, alpha=0.3, linewidth=0.5)
        ax1.tick_params(labelsize=9)
        ax2.plot(
            epochs,
            self.history["train_acc"],
            label="Training",
            linewidth=1.2,
            color="#1f77b4",
            **marker_settings["train"],
        )
        ax2.plot(
            epochs,
            self.history["val_acc"],
            label="Validation",
            linewidth=1.2,
            color="#ff7f0e",
            **marker_settings["val"],
        )
        if ci and len(epochs) > 3:
            self._add_confidence_intervals(
                ax2, epochs, self.history["train_acc"], "#1f77b4"
            )
            self._add_confidence_intervals(
                ax2, epochs, self.history["val_acc"], "#ff7f0e"
            )
        ax2.set_title(self.metric_display_name, fontsize=11, fontweight="normal")
        ax2.set_xlabel("Epoch")
        if metric is not None:
            ax2.set_ylabel(str(metric).title())
        else:
            ax2.set_ylabel(self.metric_display_name)
        ax2.legend(fontsize=9, loc="lower right")
        ax2.grid(True, alpha=0.3, linewidth=0.5)
        ax2.tick_params(labelsize=9)
        plt.tight_layout(pad=2.0)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()

    def _add_confidence_intervals(self, ax, x, y, color, alpha=0.2):
        """Add confidence intervals to plots using moving average smoothing."""
        if len(y) < 3:
            return
        y = np.array(y)
        window = min(5, len(y) // 3)
        if window < 2:
            return
        padded_y = np.pad(y, window // 2, mode="edge")
        ci_upper = []
        ci_lower = []
        for i in range(len(y)):
            window_data = padded_y[i : i + window]
            mean_val = np.mean(window_data)
            std_val = np.std(window_data)
            # 95% CI half-width: 1.96 * standard error of the window mean
            ci_upper.append(mean_val + 1.96 * std_val / np.sqrt(len(window_data)))
            ci_lower.append(mean_val - 1.96 * std_val / np.sqrt(len(window_data)))
        # Keep the x-values themselves (not their indices) so the band lines
        # up with the 1-based epoch axis used by the curves.
        x_clean = [val for val in x if not np.isnan(val)]
        ax.fill_between(x_clean, ci_lower, ci_upper, color=color, alpha=alpha)
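
    # Numeric sketch of the band computed above (a standalone example, not
    # library code): for window size w, the 95% half-width at each point is
    # 1.96 * std(window) / sqrt(w).
    #
    #   y = np.array([1.00, 0.80, 0.70, 0.65, 0.60, 0.58])
    #   w = min(5, len(y) // 3)                    # -> 2
    #   padded = np.pad(y, w // 2, mode="edge")
    #   half = [1.96 * np.std(padded[i:i + w]) / np.sqrt(w)
    #           for i in range(len(y))]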

    def plot_curves_fast(self, figsize=(10, 4), markers=True, save_path=None):
        """
        Plot learning curves for fit_fast() results.

        Args:
            figsize (tuple, optional): Figure size. Defaults to (10, 4).
            markers (bool, optional): Show actual data points. Defaults to True.
            save_path (str, optional): Save path. Defaults to None.
        """
        if self.method != "fit_fast":
            print("Use plot_learning_curves() for regular fit() results.")
            return
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        # Get clean data from fit_fast (no None values)
        epochs = self.history["epochs"]
        train_loss = self.history["train_loss"]
        train_acc = self.history["train_acc"]
        val_loss = self.history.get("val_loss", [])
        val_acc = self.history.get("val_acc", [])
        # Plot settings
        line_kw = {"linewidth": 1.2}
        marker_kw = {"marker": "o", "markersize": 3} if markers else {"marker": None}
        # Plot Loss
        ax1.plot(
            epochs,
            train_loss,
            label="Training",
            color="#1f77b4",
            **line_kw,
            **marker_kw,
        )
        if val_loss:
            ax1.plot(
                epochs,
                val_loss,
                label="Validation",
                color="#ff7f0e",
                **line_kw,
                **marker_kw,
            )
        ax1.set_title("Loss", fontsize=11, fontweight="normal")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
        ax1.legend(fontsize=9, loc="upper right")
        ax1.grid(True, alpha=0.3, linewidth=0.5)
        ax1.tick_params(labelsize=9)
        # Plot Accuracy
        ax2.plot(
            epochs, train_acc, label="Training", color="#1f77b4", **line_kw, **marker_kw
        )
        if val_acc:
            ax2.plot(
                epochs,
                val_acc,
                label="Validation",
                color="#ff7f0e",
                **line_kw,
                **marker_kw,
            )
        ax2.set_title(self.metric_display_name, fontsize=11, fontweight="normal")
        ax2.set_xlabel("Epoch")
        ax2.set_ylabel(self.metric_display_name)
        ax2.legend(fontsize=9, loc="lower right")
        ax2.grid(True, alpha=0.3, linewidth=0.5)
        ax2.tick_params(labelsize=9)
        plt.tight_layout(pad=2.0)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()
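
    # Usage sketch (model/variable names are illustrative, not part of this
    # module):
    #
    #   hist = model.fit_fast(X_train, y_train, epochs=50)
    #   Visualizer(hist).plot_curves_fast(save_path="curves_fast.png")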

    def plot_activation_hist(
        self, epoch=None, figsize=(9, 4), kde=False, last_layer=False, save_path=None
    ):
        """
        Plot activation value distributions across network layers.

        Visualizes the distribution of activation values for each layer at a
        specific epoch, aggregated from all mini-batches. Useful for detecting
        activation saturation, dead neurons, and distribution shifts during
        training.

        Args:
            epoch (int, optional): Specific epoch to plot. If None, uses last
                epoch. Defaults to None.
            figsize (tuple[int, int], optional): Figure dimensions. Defaults to (9, 4).
            kde (bool, optional): Whether to use KDE-style smoothing for
                smoother curves. Defaults to False.
            last_layer (bool, optional): Whether to include output layer.
                Defaults to False.
            save_path (str, optional): Path to save the figure. Defaults to None.

        Example:
            >>> viz.plot_activation_hist(epoch=50, kde=True, save_path='activations.png')
        """
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        return self._plot_epoch_distribution(
            "activations",
            "Activation Distributions (Epoch-Agg)",
            "Activation Value",
            epoch,
            figsize,
            kde,
            last_layer,
            save_path,
        )

    def plot_gradient_hist(
        self, epoch=None, figsize=(9, 4), kde=False, last_layer=False, save_path=None
    ):
        """
        Plot gradient value distributions across network layers.

        Visualizes gradient distributions to detect vanishing/exploding
        gradient problems, gradient flow issues, and training instability.
        Shows a zero-line reference for assessing gradient symmetry and
        magnitude.

        Args:
            epoch (int, optional): Specific epoch to plot. If None, uses last
                epoch. Defaults to None.
            figsize (tuple[int, int], optional): Figure dimensions. Defaults to (9, 4).
            kde (bool, optional): Whether to use KDE-style smoothing. Defaults to False.
            last_layer (bool, optional): Whether to include output layer
                gradients. Defaults to False.
            save_path (str, optional): Path to save the figure. Defaults to None.

        Note:
            Gradient distributions should be roughly symmetric around zero for
            healthy training. Very narrow distributions may indicate vanishing
            gradients, while very wide distributions may indicate exploding
            gradients.

        Example:
            >>> viz.plot_gradient_hist(epoch=25, kde=True, save_path='gradients.png')
        """
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        return self._plot_epoch_distribution(
            "gradients",
            "Gradient Distributions (Epoch-Agg)",
            "Gradient Value",
            epoch,
            figsize,
            kde,
            last_layer,
            save_path,
        )

    def plot_weight_hist(
        self, epoch=None, figsize=(9, 4), kde=False, last_layer=False, save_path=None
    ):
        """
        Plot weight value distributions across network layers.

        Uses aggregated samples from all mini-batches within an epoch to
        create representative distributions. Shows weight evolution patterns.

        Args:
            epoch: Specific epoch to plot (default: last epoch)
            figsize: Figure size tuple
            kde: Whether to use KDE-style smoothing
            last_layer: Whether to include output layer (default: False,
                hidden layers only)
            save_path: Path to save figure
        """
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        return self._plot_epoch_distribution(
            "weights",
            "Weight Distributions (Epoch-Agg)",
            "Weight Value",
            epoch,
            figsize,
            kde,
            last_layer,
            save_path,
        )
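
    # Usage sketch, mirroring the activation/gradient variants:
    #
    #   viz.plot_weight_hist(epoch=25, kde=True, save_path="weights.png")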

    def _plot_epoch_distribution(
        self, data_type, title, xlabel, epoch, figsize, kde, last_layer, save_path
    ):
        """
        Internal method to plot epoch-aggregated distributions.

        Uses a research-validated approach:

        1. Aggregates all mini-batch samples from an epoch
        2. Creates smooth KDE-like distributions
        3. Layer-specific coloring and styling
        4. Zero line reference for gradient analysis
        """
        fig, ax = self._configure_plot(title, xlabel, "Density", figsize)
        if not self.epoch_distributions or data_type not in self.epoch_distributions:
            ax.text(
                0.5,
                0.5,
                f"No epoch {data_type} data available\n"
                "Requires training with MLP.fit() method",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=12,
            )
            plt.show()
            return
        data_dict = self.epoch_distributions[data_type]
        num_epochs = len(list(data_dict.values())[0]) if data_dict else 0
        if num_epochs == 0:
            ax.text(
                0.5,
                0.5,
                f"No {data_type} data collected during training",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=12,
            )
            plt.show()
            return
        target_epoch = epoch if epoch is not None else num_epochs - 1
        if target_epoch >= num_epochs or target_epoch < 0:
            target_epoch = num_epochs - 1
        layers_to_plot = list(data_dict.keys())
        if not last_layer:
            layers_to_plot = layers_to_plot[:-1]
        layer_colors = self.colors["layers"][: len(layers_to_plot)]
        for i, (layer_key, color) in enumerate(zip(layers_to_plot, layer_colors)):
            if target_epoch < len(data_dict[layer_key]):
                epoch_samples = data_dict[layer_key][target_epoch]
                if len(epoch_samples) > 0:
                    bins = min(50, max(15, int(np.sqrt(len(epoch_samples)))))
                    if kde:
                        counts, bin_edges = np.histogram(
                            epoch_samples, bins=bins, density=True
                        )
                        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
                        x_smooth = np.linspace(
                            bin_centers.min(), bin_centers.max(), len(bin_centers) * 3
                        )
                        y_smooth = np.interp(x_smooth, bin_centers, counts)
                        ax.plot(
                            x_smooth,
                            y_smooth,
                            color=color,
                            linewidth=2.0,
                            alpha=0.9,
                            label=f"Layer {i + 1} (n={len(epoch_samples)})",
                        )
                        ax.fill_between(x_smooth, y_smooth, alpha=0.3, color=color)
                    else:
                        # Standard histogram
                        ax.hist(
                            epoch_samples,
                            bins=bins,
                            density=True,
                            alpha=0.7,
                            color=color,
                            label=f"Layer {i + 1} (n={len(epoch_samples)})",
                            edgecolor="black",
                            linewidth=0.5,
                        )
        if data_type in ["gradients", "weights"]:
            ax.axvline(
                0,
                color="red",
                linestyle="--",
                alpha=0.8,
                linewidth=1.5,
                label="Zero line",
            )
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1), loc="upper left")
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()
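
    # The kde=True branch above is histogram interpolation rather than a true
    # kernel density estimate: a density histogram is resampled onto a 3x
    # finer grid with np.interp. Standalone sketch of the same idea:
    #
    #   samples = np.random.randn(1000)
    #   counts, edges = np.histogram(samples, bins=30, density=True)
    #   centers = (edges[:-1] + edges[1:]) / 2
    #   xs = np.linspace(centers.min(), centers.max(), len(centers) * 3)
    #   ys = np.interp(xs, centers, counts)   # smoother-looking, not a KDE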

    def plot_activation_stats(
        self,
        activation_stats=None,
        figsize=(12, 4),
        save_path=None,
        reference_lines=False,
    ):
        """
        Plot activation statistics over time with both mean and std.

        Args:
            activation_stats: Dict of layer activation stats, or None to use
                the data collected by this instance.
                Format: {'layer_0': {'mean': [...], 'std': [...]}, ...}
            figsize: Figure size tuple
            save_path: Path to save figure
            reference_lines: Whether to draw healthy-range reference lines
        """
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        fig, ax = self._configure_plot(
            "Activation Statistics Over Epochs", "Epoch", "Mean Activation", figsize
        )
        data_source = (
            activation_stats
            if activation_stats is not None
            else self.activation_stats_over_epochs
        )
        if not data_source:
            ax.text(
                0.5,
                0.5,
                "No activation stats data available\n"
                "Please run model.fit() to generate statistics",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=12,
            )
            plt.show()
            return
        first_layer = list(data_source.values())[0]
        has_mean_std = isinstance(first_layer, dict) and "mean" in first_layer
        if has_mean_std:
            epochs = np.arange(len(first_layer["mean"])) + 1
        else:
            epochs = np.arange(len(first_layer)) + 1
        layer_colors = self.colors["layers"][: len(data_source)]
        for i, (layer_name, color) in enumerate(zip(data_source.keys(), layer_colors)):
            if has_mean_std:
                mean_values = data_source[layer_name]["mean"]
                ax.plot(
                    epochs,
                    mean_values,
                    label=f"Layer {i + 1} Mean",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"],
                    marker="o",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                    linestyle="-",
                )
                std_values = data_source[layer_name]["std"]
                ax.plot(
                    epochs,
                    std_values,
                    label=f"Layer {i + 1} Std",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"] * 0.7,
                    marker="s",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                    linestyle="--",
                )
            else:
                y_values = data_source[layer_name]
                ax.plot(
                    epochs,
                    y_values,
                    label=f"Layer {i + 1}",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"],
                    marker="o",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                )
        # Healthy range reference lines for activation magnitudes during
        # training; Xavier (2010), He (2015), and Batch Norm literature put
        # roughly 0.01-5.0 as the healthy range (applies to both mean and std).
        if reference_lines:
            ax.axhline(
                y=0.01,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Min (0.01)",
            )
            ax.axhline(
                y=5.0,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Max (5.0)",
            )
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()

    def plot_gradient_stats(
        self, figsize=(12, 4), save_path=None, reference_lines=False
    ):
        """Plot gradient statistics over time with both mean and std."""
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        fig, ax = self._configure_plot(
            "Gradient Statistics Over Epochs", "Epoch", "Mean |Gradient|", figsize
        )
        if not self.gradient_stats_over_epochs:
            ax.text(
                0.5,
                0.5,
                "No gradient stats data available\n"
                "Please run model.fit() to generate statistics",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=12,
            )
            plt.show()
            return
        first_layer = list(self.gradient_stats_over_epochs.values())[0]
        has_mean_std = isinstance(first_layer, dict) and "mean" in first_layer
        if has_mean_std:
            epochs = np.arange(len(first_layer["mean"])) + 1
        else:
            epochs = np.arange(len(first_layer)) + 1
        layer_colors = self.colors["layers"][: len(self.gradient_stats_over_epochs)]
        for i, (layer_name, color) in enumerate(
            zip(self.gradient_stats_over_epochs.keys(), layer_colors)
        ):
            if has_mean_std:
                mean_values = self.gradient_stats_over_epochs[layer_name]["mean"]
                ax.plot(
                    epochs,
                    mean_values,
                    label=f"Layer {i + 1} Mean",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"],
                    marker="s",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                    linestyle="-",
                )
                std_values = self.gradient_stats_over_epochs[layer_name]["std"]
                ax.plot(
                    epochs,
                    std_values,
                    label=f"Layer {i + 1} Std",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"] * 0.7,
                    marker="D",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                    linestyle="--",
                )
            else:
                y_values = self.gradient_stats_over_epochs[layer_name]
                ax.plot(
                    epochs,
                    y_values,
                    label=f"Layer {i + 1}",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"],
                    marker="s",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                )
        # Healthy range reference lines for gradient magnitudes during
        # training; Bengio (1994), Pascanu (2013), Goodfellow (2016) put
        # roughly 1e-4 to 1e-1 as the healthy range (mean and std alike).
        if reference_lines:
            ax.axhline(
                y=1e-4,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Min (1e-4)",
            )
            ax.axhline(
                y=1e-1,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Max (1e-1)",
            )
        ax.set_yscale("log")
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()

    def plot_weight_stats(self, figsize=(12, 4), save_path=None, reference_lines=False):
        """Plot weight statistics over time with both mean and std."""
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        fig, ax = self._configure_plot(
            "Weight Statistics Over Epochs", "Epoch", "Mean |Weight|", figsize
        )
        if not self.weight_stats_over_epochs:
            ax.text(
                0.5,
                0.5,
                "No weight stats data available\n"
                "Please run model.fit() to generate statistics",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=12,
            )
            plt.show()
            return
        first_layer = list(self.weight_stats_over_epochs.values())[0]
        has_mean_std = isinstance(first_layer, dict) and "mean" in first_layer
        if has_mean_std:
            epochs = np.arange(len(first_layer["mean"])) + 1
        else:
            epochs = np.arange(len(first_layer)) + 1
        layer_colors = self.colors["layers"][: len(self.weight_stats_over_epochs)]
        for i, (layer_name, color) in enumerate(
            zip(self.weight_stats_over_epochs.keys(), layer_colors)
        ):
            if has_mean_std:
                mean_values = self.weight_stats_over_epochs[layer_name]["mean"]
                ax.plot(
                    epochs,
                    mean_values,
                    label=f"Layer {i + 1} Mean",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"],
                    marker="^",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                    linestyle="-",
                )
                std_values = self.weight_stats_over_epochs[layer_name]["std"]
                ax.plot(
                    epochs,
                    std_values,
                    label=f"Layer {i + 1} Std",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"] * 0.7,
                    marker="v",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                    linestyle="--",
                )
            else:
                y_values = self.weight_stats_over_epochs[layer_name]
                ax.plot(
                    epochs,
                    y_values,
                    label=f"Layer {i + 1}",
                    color=color,
                    linewidth=self.line_style["width"],
                    alpha=self.line_style["alpha"],
                    marker="^",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                )
        # Healthy range reference lines for TRAINED neural network weights;
        # ResNet, VGG, and modern CNN literature put roughly 0.05-0.8 as the
        # healthy range for trained weights (these are post-training weight
        # magnitudes, NOT initialization ranges).
        if reference_lines:
            ax.axhline(
                y=0.05,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Min (0.05)",
            )
            ax.axhline(
                y=0.8,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Max (0.8)",
            )
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()

    def plot_update_ratios(
        self, update_ratios=None, figsize=(12, 4), save_path=None, reference_lines=False
    ):
        """
        Plot weight update ratios per layer across epochs.

        Args:
            update_ratios: Dict of layer update ratios, or None to use the
                data collected by this instance.
                Format: {'layer_0': [ratio_epoch_0, ratio_epoch_1, ...], ...}
            figsize: Figure size tuple
            save_path: Path to save figure
            reference_lines: Whether to draw healthy-range reference lines
        """
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        fig, ax = self._configure_plot(
            "Weight Update Ratios Over Epochs", "Epoch", "Update Ratio", figsize
        )
        data_source = (
            update_ratios if update_ratios is not None else self.weight_update_ratios
        )
        if not data_source:
            ax.text(
                0.5,
                0.5,
                "No weight update ratio data available\n"
                "Please run model.fit() to generate statistics",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=12,
            )
            plt.show()
            return
        epochs = np.arange(len(list(data_source.values())[0])) + 1
        layer_colors = self.colors["layers"][: len(data_source)]
        for i, (layer_name, color) in enumerate(zip(data_source.keys(), layer_colors)):
            y_values = data_source[layer_name]
            ax.plot(
                epochs,
                y_values,
                label=f"Layer {i + 1}",
                color=color,
                linewidth=self.line_style["width"],
                marker="o",
                markersize=self.line_style["markersize"],
                markevery=self.line_style["markevery"],
            )
        if reference_lines:
            ax.axhline(
                y=1e-4,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Min (1e-4)",
            )
            ax.axhline(
                y=1e-2,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Healthy Max (1e-2)",
            )
        ax.set_yscale("log")
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()
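
    # The plotted quantity is a per-layer update ratio per epoch, i.e. how
    # large each weight update is relative to the weights themselves. A
    # minimal sketch of the assumed form (not NeuroScope's exact code):
    #
    #   delta_w = learning_rate * grad_w                        # SGD step
    #   ratio = np.linalg.norm(delta_w) / (np.linalg.norm(w) + 1e-12)
    #
    # The reference lines above mark 1e-4 to 1e-2 as the healthy band;
    # ratios near 1e-3 are a common rule of thumb.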

    def plot_gradient_norms(
        self,
        gradient_norms=None,
        figsize=(12, 4),
        save_path=None,
        reference_lines=False,
    ):
        """Plot gradient norms per layer over epochs."""
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return
        fig, ax = self._configure_plot(
            "Gradient Norms Over Epochs", "Epoch", "Gradient Norm", figsize
        )
        data_source = (
            gradient_norms if gradient_norms is not None else self.gradient_norms
        )
        if not data_source:
            print("No gradient norm data available")
            return
        layer_colors = self.colors["layers"][: len(data_source)]
        epochs = np.arange(len(list(data_source.values())[0])) + 1
        for i, (layer_name, color) in enumerate(zip(data_source.keys(), layer_colors)):
            norms = data_source[layer_name]
            if norms:
                ax.plot(
                    epochs,
                    norms,
                    label=f"Layer {i + 1}",
                    color=color,
                    linewidth=self.line_style["width"],
                    marker="s",
                    markersize=self.line_style["markersize"],
                    markevery=self.line_style["markevery"],
                )
        # Reference lines based on gradient clipping literature:
        # Pascanu et al. (2013), Sutskever et al. (2014): 1.0 is a standard
        # clipping threshold; Bengio et al. (1994): below 0.01 suggests
        # vanishing gradients.
        if reference_lines:
            ax.axhline(
                y=1.0,
                color=self.colors["reference"]["good"],
                linestyle="--",
                alpha=0.7,
                label="Clipping Threshold (1.0)",
            )
            ax.axhline(
                y=0.01,
                color=self.colors["reference"]["warning"],
                linestyle="--",
                alpha=0.7,
                label="Vanishing Gradient Warning (0.01)",
            )
        ax.set_yscale("log")
        ax.legend(fontsize=9, bbox_to_anchor=(1.02, 1), loc="upper left")
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white")
        plt.show()
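
    # A per-layer gradient norm, as plotted above, is typically the L2 norm
    # of the flattened weight gradient (assumed form; the values here are
    # collected by fit() during training, not computed in this module):
    #
    #   norm = float(np.linalg.norm(grad_w.ravel()))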

    def plot_training_animation(self, bg="dark", save_path=None):
        """
        Create a comprehensive 4-panel GIF animation showing:

        1. Loss curves over time
        2. Accuracy curves over time
        3. Current metrics bar chart
        4. Gradient flow analysis

        Speed automatically adjusts based on epoch count for a smooth motion
        feel.

        Args:
            bg: Theme ('dark' or 'light')
            save_path: Path to save GIF (defaults to 'mlp_training_animation.gif')

        Returns:
            Path to saved GIF file, or None if no data was available or
            saving failed.
        """
        if self.method == "fit_fast":
            print("This method is only available for .fit() results.")
            return None
        epochs = len(self.history.get("train_loss", []))
        if epochs == 0:
            print("No training data found - cannot create animation")
            return None
        fps = 3 if epochs < 50 else 5 if epochs <= 100 else 8
        colors = {
            "primary": "#2E86AB",
            "secondary": "#F18F01",
            "accent": "#28D631",
            "success": "#FF2727",
            "secondary_accent": "#FFC062",
            "primary_accent": "#63CBF8",
            "background": "#0A0E27" if bg == "dark" else "#FFFFFF",
            "text": "#FFFFFF" if bg == "dark" else "#000000",
            "grid": "#333333" if bg == "dark" else "#CCCCCC",
            "spine": "#555555" if bg == "dark" else "#999999",
        }
        fig = plt.figure(figsize=(18, 10))
        fig.patch.set_facecolor(colors["background"])
        axes = [
            plt.subplot(2, 3, 1),  # Loss curves
            plt.subplot(2, 3, 2),  # Accuracy curves
            plt.subplot(2, 3, 3),  # Metrics bar chart
            plt.subplot(2, 1, 2),  # Gradient flow
        ]
        plt.subplots_adjust(hspace=0.3, wspace=0.3)

        def style_axis(ax):
            """Apply consistent styling to axis"""
            ax.set_facecolor(colors["background"])
            ax.tick_params(colors=colors["text"])
            for spine in ax.spines.values():
                spine.set_color(colors["spine"])

        def style_legend(legend):
            """Apply consistent styling to legend"""
            legend.get_frame().set_facecolor(colors["background"])
            legend.get_frame().set_edgecolor(colors["spine"])
            for text in legend.get_texts():
                text.set_color(colors["text"])

        for ax in axes:
            style_axis(ax)
        train_loss, val_loss = self.history.get("train_loss", []), self.history.get(
            "val_loss", []
        )
        train_acc, val_acc = self.history.get("train_acc", []), self.history.get(
            "val_acc", []
        )

        def animate(frame):
            for ax in axes:
                ax.clear()
                style_axis(ax)
            current_epoch = frame + 1
            epochs_range = list(range(1, current_epoch + 1))

            # 1. Loss Evolution
            axes[0].set_title(
                "Loss Evolution", color=colors["text"], fontweight="bold", fontsize=15
            )
            if len(train_loss) >= current_epoch:
                train_losses_current = train_loss[:current_epoch]
                axes[0].plot(
                    epochs_range,
                    train_losses_current,
                    color=colors["primary"],
                    linewidth=2.5,
                    label="Train Loss",
                    alpha=0.9,
                )
            else:
                train_losses_current = []
            if len(val_loss) >= current_epoch:
                val_losses_current = val_loss[:current_epoch]
                axes[0].plot(
                    epochs_range,
                    val_losses_current,
                    color=colors["secondary"],
                    linewidth=2.5,
                    label="Val Loss",
                    alpha=0.9,
                )
            else:
                val_losses_current = []
            # Generalization gap visualization
            if train_losses_current and val_losses_current:
                min_length = min(len(train_losses_current), len(val_losses_current))
                epochs_common = epochs_range[:min_length]
                train_common = train_losses_current[:min_length]
                val_common = val_losses_current[:min_length]
                axes[0].fill_between(
                    epochs_common,
                    train_common,
                    val_common,
                    where=np.array(val_common) > np.array(train_common),
                    color=colors["secondary_accent"],
                    alpha=0.4,
                    interpolate=True,
                )
                axes[0].fill_between(
                    epochs_common,
                    train_common,
                    val_common,
                    where=np.array(val_common) <= np.array(train_common),
                    color=colors["primary_accent"],
                    alpha=0.3,
                    interpolate=True,
                )
            axes[0].set_xlabel("Epoch", color=colors["text"])
            axes[0].set_ylabel("Loss", color=colors["text"])
            axes[0].grid(True, alpha=0.3, color=colors["grid"])
            style_legend(axes[0].legend(loc="upper right"))

            # 2. Metric Evolution
            axes[1].set_title(
                f"{self.metric_display_name} Evolution",
                color=colors["text"],
                fontweight="bold",
                fontsize=16,
            )
            if len(train_acc) >= current_epoch:
                axes[1].plot(
                    epochs_range,
                    train_acc[:current_epoch],
                    color=colors["primary"],
                    linewidth=2,
                    label=f"Train {self.metric_display_name}",
                )
            if len(val_acc) >= current_epoch:
                axes[1].plot(
                    epochs_range,
                    val_acc[:current_epoch],
                    color=colors["secondary"],
                    linewidth=2,
                    label=f"Val {self.metric_display_name}",
                )
            axes[1].set_xlabel("Epoch", color=colors["text"])
            axes[1].set_ylabel(self.metric_display_name, color=colors["text"])
            axes[1].grid(True, alpha=0.3, color=colors["grid"])
            style_legend(axes[1].legend())

            # 3. Current Metrics Bar Chart
            axes[2].set_title(
                "Current Metrics", color=colors["text"], fontweight="bold", fontsize=16
            )
            metrics = [
                "Train Loss",
                "Val Loss",
                f"Train {self.metric_display_name}",
                f"Val {self.metric_display_name}",
            ]
            data_sources = [train_loss, val_loss, train_acc, val_acc]
            values = [
                data[current_epoch - 1] if len(data) >= current_epoch else 0
                for data in data_sources
            ]
            bar_colors = [colors["primary"], colors["secondary"]] * 2
            max_val = max(values) if any(values) else 1
            for i, (value, color) in enumerate(zip(values, bar_colors)):
                if value > 0:
                    axes[2].bar(
                        i, value, width=0.6, color=color, alpha=0.7, edgecolor="none"
                    )
                    axes[2].text(
                        i,
                        value + max_val * 0.02,
                        f"{value:.3f}",
                        ha="center",
                        va="bottom",
                        color=colors["text"],
                        fontweight="bold",
                        fontsize=8,
                    )
            axes[2].set_xlim(-0.5, len(metrics) - 0.5)
            axes[2].set_ylim(0, max_val * 1.15)
            axes[2].set_xticks(range(len(metrics)))
            axes[2].set_xticklabels(metrics, color=colors["text"])
            axes[2].set_ylabel("Value", color=colors["text"])

            # 4. Gradient Flow Analysis
            axes[3].set_title(
                "Gradient Flow Across Layers",
                color=colors["text"],
                fontweight="bold",
                fontsize=16,
            )
            if (
                hasattr(self, "gradient_stats_over_epochs")
                and self.gradient_stats_over_epochs
                and current_epoch > 0
            ):
                layer_colors = [
                    colors["primary"],
                    colors["secondary"],
                    colors["accent"],
                ]
                for i, (layer_key, layer_stats) in enumerate(
                    self.gradient_stats_over_epochs.items()
                ):
                    if isinstance(layer_stats, dict) and "mean" in layer_stats:
                        if current_epoch <= len(layer_stats["mean"]):
                            gradient_means = layer_stats["mean"][:current_epoch]
                            color = layer_colors[i % len(layer_colors)]
                            axes[3].plot(
                                epochs_range,
                                gradient_means,
                                color=color,
                                linewidth=2,
                                alpha=0.8,
                                label=f"Layer {i + 1}",
                                marker="o",
                                markersize=3,
                            )
                axes[3].set_xlabel("Epoch", color=colors["text"])
                axes[3].set_ylabel("Mean Gradient Magnitude", color=colors["text"])
                axes[3].set_yscale("log")
                axes[3].set_ylim(1e-4, 1e0)
                axes[3].grid(True, alpha=0.2, color=colors["grid"])
                style_legend(axes[3].legend(loc="upper right"))
                # Reference lines for healthy gradient ranges
                axes[3].axhline(
                    y=1e-7,
                    color=colors["success"],
                    linestyle=":",
                    alpha=0.5,
                    linewidth=1,
                )
                axes[3].axhline(
                    y=1e-2,
                    color=colors["success"],
                    linestyle=":",
                    alpha=0.5,
                    linewidth=1,
                )
            else:
                axes[3].text(
                    0.5,
                    0.5,
                    "Gradient Flow Analysis\n(Gradient stats not available)",
                    ha="center",
                    va="center",
                    transform=axes[3].transAxes,
                    color=colors["text"],
                    fontsize=12,
                    alpha=0.7,
                )
            fig.suptitle(
                f"MLP Training Animation - Epoch {current_epoch}/{epochs}",
                fontsize=22,
                fontweight="bold",
                color=colors["text"],
            )

        # Create and save animation
        anim = animation.FuncAnimation(
            fig, animate, frames=epochs, interval=1000 // fps, repeat=True
        )
        output_path = save_path if save_path else "mlp_training_animation.gif"
        print("Building Animation - this may take a while...")
        try:
            anim.save(output_path, writer=PillowWriter(fps=fps))
            print(f"Saved animation to: {output_path}")
        except Exception as e:
            print(f"Error saving animation: {e}")
            output_path = None
        plt.close(fig)
        return output_path
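

if __name__ == "__main__":
    # Hedged smoke-test: a synthetic fit_fast-style history (keys inferred
    # from __init__ and plot_curves_fast above), useful for checking the
    # styling without training a real model. Not part of the public API.
    rng = np.random.default_rng(0)
    n = 20
    fake_hist = {
        "method": "fit_fast",
        "history": {
            "epochs": list(range(1, n + 1)),
            "train_loss": list(np.exp(-0.10 * np.arange(n)) + rng.normal(0, 0.01, n)),
            "val_loss": list(np.exp(-0.08 * np.arange(n)) + rng.normal(0, 0.02, n)),
            "train_acc": list(1 - np.exp(-0.15 * np.arange(n))),
            "val_acc": list(1 - np.exp(-0.12 * np.arange(n))),
        },
        "metric": "accuracy",
        "metric_display_name": "Accuracy",
    }
    Visualizer(fake_hist).plot_curves_fast()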