Coverage for src/pycse/sklearn/dpose.py: 77.70%
278 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
1"""A DPOSE (Direct Propagation of Shallow Ensembles) Neural Network model in JAX.
3Implementation based on:
4Kellner, M., & Ceriotti, M. (2024). Uncertainty quantification by direct
5propagation of shallow ensembles. Machine Learning: Science and Technology, 5(3), 035006.
7Key features:
8- Shallow ensemble architecture (only last layer differs across members)
9- CRPS loss for robust, calibrated uncertainty estimates (default)
10- Alternative NLL or MSE losses available
11- Post-hoc calibration on validation set
12- Ensemble propagation for derived quantities
14Example usage:
16 import jax
17 import numpy as np
18 import matplotlib.pyplot as plt
20 # Generate heteroscedastic data
21 key = jax.random.PRNGKey(19)
22 x = np.linspace(0, 1, 100)[:, None]
23 noise_level = 0.01 + 0.1 * x.ravel() # Increasing noise
24 y = x.ravel()**(1/3) + noise_level * jax.random.normal(key, (100,))
26 # Split into train/validation
27 from sklearn.model_selection import train_test_split
28 x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
30 # Train with DPOSE (uses CRPS loss and BFGS optimizer by default)
31 from pycse.sklearn.dpose import DPOSE
32 model = DPOSE(layers=(1, 15, 32))
33 model.fit(x_train, y_train, val_X=x_val, val_y=y_val)
35 # Or use a different optimizer (e.g., Adam)
36 model_adam = DPOSE(layers=(1, 15, 32), optimizer='adam')
37 model_adam.fit(x_train, y_train, val_X=x_val, val_y=y_val, learning_rate=1e-3)
39 # Or use Muon optimizer (state-of-the-art 2024)
40 model_muon = DPOSE(layers=(1, 15, 32), optimizer='muon')
41 model_muon.fit(x_train, y_train, val_X=x_val, val_y=y_val, learning_rate=0.02)
43 # Get predictions with uncertainty
44 y_pred, y_std = model.predict(x, return_std=True)
46 # Visualize
47 model.plot(x, y, distribution=True)
49 # For uncertainty propagation on derived quantities
50 ensemble_preds = model.predict_ensemble(x) # (n_samples, n_ensemble)
51 # Apply any function f to ensemble members
52 z_ensemble = f(ensemble_preds)
53 z_mean = z_ensemble.mean(axis=1)
54 z_std = z_ensemble.std(axis=1)
56Requires: flax, jaxopt, jax, scikit-learn
57"""
59import os
60import jax
63from jax import jit
64import jax.numpy as np
65from jax import value_and_grad
66import jaxopt
67import optax
68import matplotlib.pyplot as plt
69from sklearn.base import BaseEstimator, RegressorMixin
70from flax import linen as nn
71from flax.linen.initializers import xavier_uniform
73os.environ["JAX_ENABLE_X64"] = "True"
74jax.config.update("jax_enable_x64", True)
class _NN(nn.Module):
    """A flax neural network with a shallow-ensemble output layer.

    layers: a Tuple of integers specifying the network architecture.
        - layers[0]: Input dimension (number of features)
        - layers[1:-1]: Hidden layer sizes
        - layers[-1]: Output dimension (ensemble size)

    Example: layers=(5, 20, 32) creates:
        - Input: 5 features
        - Hidden: 20 neurons with activation
        - Output: 32 ensemble members (no activation)
    """

    # Architecture specification; layers[0] is the input dimension and is
    # never used to build a Dense layer (flax infers input size lazily).
    layers: tuple
    # Nonlinearity applied after each hidden Dense layer only.
    activation: callable

    @nn.compact
    def __call__(self, x):
        # Hidden layers (skip first element, which is the input dimension,
        # and the last element, which is the ensemble/output width).
        for i in self.layers[1:-1]:
            x = nn.Dense(i, kernel_init=xavier_uniform())(x)
            x = self.activation(x)

        # Linear last layer where each row is a set of ensemble predictions.
        # The mean on axis=1 is the point prediction; the spread on axis=1
        # is the uncertainty. No activation, so outputs are unbounded.
        x = nn.Dense(self.layers[-1])(x)
        return x
107class DPOSE(BaseEstimator, RegressorMixin):
108 """DPOSE: Direct Propagation of Shallow Ensembles.
110 A shallow ensemble neural network where only the last layer differs across
111 ensemble members. Provides calibrated uncertainty estimates through CRPS or
112 NLL training.
114 The last element of `layers` determines the ensemble size (n_ensemble).
115 For example, layers=(5, 10, 32) creates:
116 - Input layer: 5 features
117 - Hidden layer: 10 neurons
118 - Output layer: 32 ensemble members
120 Key Features:
121 - CRPS loss (default): Robust, works out-of-the-box
122 - NLL loss: Automatically pre-trains with MSE for robustness
123 - Post-hoc calibration on validation set
124 - Uncertainty propagation through ensemble members
125 """
    def __init__(
        self,
        layers,
        activation=nn.relu,
        seed=19,
        loss_type="crps",
        min_sigma=1e-3,
        optimizer="bfgs",
    ):
        """Initialize a DPOSE model.

        Args:
            layers: Tuple of integers for neurons in each layer.
                The last value is the ensemble size (recommended: 16-64).
            activation: Activation function for hidden layers (default: ReLU).
            seed: Random seed for weight initialization.
            loss_type: Loss function - 'crps', 'nll', or 'mse' (default: 'crps').
                - 'crps': Continuous ranked probability score (Kellner Eq. 18) - RECOMMENDED
                  More robust, prevents uncertainty inflation, works out-of-the-box.
                - 'nll': Negative log-likelihood (Kellner Eq. 6)
                  Can fail without pre-training or normalization (see WHY_NLL_FAILS.md).
                - 'mse': Mean squared error (no uncertainty training)
            min_sigma: Minimum standard deviation for numerical stability (default: 1e-3).
                Prevents division by zero when ensemble members are nearly identical.
            optimizer: Optimization algorithm (default: 'bfgs'). Options:
                - 'bfgs': BFGS (quasi-Newton, recommended for smooth objectives)
                - 'lbfgs': Limited-memory BFGS (for larger problems)
                - 'adam': Adam (adaptive learning rate)
                - 'sgd': Stochastic gradient descent
                - 'muon': Muon (orthogonalized momentum, state-of-the-art 2024)
                - 'lbfgsb': L-BFGS-B (with box constraints)
                - 'nonlinear_cg': Nonlinear conjugate gradient
                - 'gradient_descent': Basic gradient descent
        """
        self.layers = layers
        # The width of the output layer IS the ensemble size.
        self.n_ensemble = layers[-1]
        self.key = jax.random.PRNGKey(seed)
        # Handle activation=None by defaulting to ReLU so callers can pass
        # None without crashing inside flax.
        if activation is None:
            activation = nn.relu
        self.nn = _NN(layers, activation)
        self.loss_type = loss_type
        self.min_sigma = min_sigma
        # Normalize so 'BFGS', 'Adam', etc. are accepted too.
        self.optimizer = optimizer.lower()
        self.calibration_factor = 1.0  # Default: no calibration
173 def fit(self, X, y, val_X=None, val_y=None, pretrain_with_mse=None, **kwargs):
174 """Fit the DPOSE model with calibrated uncertainty estimation.
176 Args:
177 X: Training features, shape (n_samples, n_features).
178 y: Training targets, shape (n_samples,).
179 val_X: Optional validation features for post-hoc calibration.
180 val_y: Optional validation targets for post-hoc calibration.
181 pretrain_with_mse: If True, pre-train with MSE before NLL training (default: auto).
182 - For 'nll': defaults to True (robust two-stage training)
183 - For 'crps'/'mse': defaults to False (not needed)
184 Set to False to disable pre-training for NLL (not recommended).
185 **kwargs: Additional arguments passed to the optimizer. Common parameters:
186 - maxiter: Maximum iterations (default: 1500)
187 - tol: Convergence tolerance (default: 1e-3)
188 - pretrain_maxiter: Iterations for MSE pre-training (default: 500)
190 Optimizer-specific parameters:
191 - BFGS/LBFGS: stepsize, linesearch, max_linesearch_iter
192 - Adam: learning_rate (default: 1e-3), b1, b2, eps
193 - SGD: learning_rate, momentum
194 - Muon: learning_rate (default: 0.02), beta (default: 0.95),
195 ns_steps (default: 5), weight_decay
196 - See jaxopt documentation for full parameter lists
198 Returns:
199 self: Fitted model.
200 """
201 # Auto-detect if we should pre-train
202 if pretrain_with_mse is None:
203 pretrain_with_mse = self.loss_type == "nll"
205 # Extract pre-training specific kwargs
206 pretrain_maxiter = kwargs.pop("pretrain_maxiter", 500)
208 # Stage 1: MSE pre-training (if using NLL and pretrain enabled)
209 if pretrain_with_mse and self.loss_type == "nll":
210 print("\n" + "=" * 70)
211 print("NLL TRAINING: Two-Stage Approach for Robustness")
212 print("=" * 70)
213 print(f"Stage 1: MSE pre-training ({pretrain_maxiter} iterations)")
214 print(" → Ensures good predictions before uncertainty calibration")
216 # Temporarily switch to MSE
217 original_loss = self.loss_type
218 self.loss_type = "mse"
220 # Create kwargs for pre-training
221 pretrain_kwargs = kwargs.copy()
222 pretrain_kwargs["maxiter"] = pretrain_maxiter
224 # Pre-train with MSE
225 self._fit_internal(X, y, val_X=None, val_y=None, **pretrain_kwargs)
227 # Report pre-training results
228 y_pred_pretrain = self.predict(X)
229 mae_pretrain = np.mean(np.abs(y - y_pred_pretrain))
230 print(f" ✓ Pre-training complete: MAE = {mae_pretrain:.6f}")
232 # Restore NLL
233 self.loss_type = original_loss
235 print("\nStage 2: NLL fine-tuning (uncertainty calibration)")
236 print(" → Calibrating uncertainties while maintaining accuracy")
237 print("=" * 70 + "\n")
239 # Stage 2: Main training (NLL, CRPS, or MSE)
240 return self._fit_internal(X, y, val_X, val_y, **kwargs)
    def _fit_internal(self, X, y, val_X=None, val_y=None, **kwargs):
        """Internal method for actual fitting (used by fit() for each stage).

        This is separated out to enable two-stage training for NLL.

        Args:
            X: Training features, shape (n_samples, n_features).
            y: Training targets, shape (n_samples,).
            val_X: Optional validation features; if given (with val_y),
                post-hoc calibration runs after optimization.
            val_y: Optional validation targets.
            **kwargs: Solver options (maxiter, tol, learning_rate, ...).
                NOTE: this dict is mutated (defaults inserted, optimizer-
                specific keys popped) before being forwarded to the solver.

        Returns:
            self: Fitted model; optimized parameters in self.optpars,
            solver state in self.state.
        """
        # Initialize fresh parameters on the first fit; otherwise warm-start
        # from the previous solution (this is what lets MSE pre-training
        # seed the NLL stage).
        if not hasattr(self, "optpars"):
            params = self.nn.init(self.key, X)  # Dummy input to init
        else:
            params = self.optpars

        @jit
        def objective(pars):
            """Loss function with per-sample uncertainty from ensemble spread."""
            # Get ensemble predictions: shape (n_samples, n_ensemble)
            pY = self.nn.apply(pars, np.asarray(X))

            # Ensemble statistics
            py = pY.mean(axis=1)  # Predicted mean (n_samples,)
            # Uncertainty with numerically stable gradient (the min_sigma**2
            # floor avoids NaN gradients when ensemble members are identical)
            sigma = np.sqrt(
                pY.var(axis=1) + self.min_sigma**2
            )  # Predicted uncertainty (n_samples,)

            # Prediction errors
            errs = np.asarray(y).ravel() - py

            if self.loss_type == "nll":
                # Negative Log-Likelihood (Kellner & Ceriotti, Eq. 6)
                # Penalizes both prediction errors AND miscalibrated uncertainties
                nll = 0.5 * (errs**2 / sigma**2 + np.log(2 * np.pi * sigma**2))
                return np.mean(nll)

            elif self.loss_type == "crps":
                # Continuous Ranked Probability Score (Kellner & Ceriotti, Eq. 18)
                # Closed form for a Gaussian predictive distribution;
                # more robust than NLL, less sensitive to outliers
                z = errs / sigma
                phi_z = jax.scipy.stats.norm.pdf(z)
                Phi_z = jax.scipy.stats.norm.cdf(z)
                crps = sigma * (z * (2 * Phi_z - 1) + 2 * phi_z - 1 / np.sqrt(np.pi))
                return np.mean(crps)

            elif self.loss_type == "mse":
                # Simple MSE (no uncertainty training)
                return np.mean(errs**2)

            else:
                raise ValueError(
                    f"Unknown loss_type: {self.loss_type}. Use 'nll', 'crps', or 'mse'."
                )

        # Solver configuration defaults (only if the caller did not set them)
        if "maxiter" not in kwargs:
            kwargs["maxiter"] = 1500
        if "tol" not in kwargs:
            kwargs["tol"] = 1e-3

        # Select optimizer. jaxopt second-order/line-search solvers take the
        # objective via value_and_grad; optax-based solvers take it directly.
        if self.optimizer == "bfgs":
            solver = jaxopt.BFGS(fun=value_and_grad(objective), value_and_grad=True, **kwargs)
        elif self.optimizer == "lbfgs":
            solver = jaxopt.LBFGS(fun=value_and_grad(objective), value_and_grad=True, **kwargs)
        elif self.optimizer == "lbfgsb":
            solver = jaxopt.LBFGSB(fun=value_and_grad(objective), value_and_grad=True, **kwargs)
        elif self.optimizer == "nonlinear_cg":
            solver = jaxopt.NonlinearCG(
                fun=value_and_grad(objective), value_and_grad=True, **kwargs
            )
        elif self.optimizer == "adam":
            # Adam uses OptaxSolver with optax optimizer
            if "learning_rate" not in kwargs:
                kwargs["learning_rate"] = 1e-3
            solver = jaxopt.OptaxSolver(
                opt=optax.adam(kwargs.pop("learning_rate")), fun=objective, **kwargs
            )
        elif self.optimizer == "sgd":
            # SGD uses OptaxSolver with optax optimizer
            if "learning_rate" not in kwargs:
                kwargs["learning_rate"] = 1e-2
            lr = kwargs.pop("learning_rate")
            momentum = kwargs.pop("momentum", 0.9)
            solver = jaxopt.OptaxSolver(
                opt=optax.sgd(lr, momentum=momentum), fun=objective, **kwargs
            )
        elif self.optimizer == "muon":
            # Muon uses OptaxSolver with optax.contrib.muon
            # Muon orthogonalizes momentum updates for 2D parameters
            if "learning_rate" not in kwargs:
                kwargs["learning_rate"] = 0.02  # Muon typically uses higher LR than Adam
            lr = kwargs.pop("learning_rate")
            beta = kwargs.pop("beta", 0.95)
            ns_steps = kwargs.pop("ns_steps", 5)
            weight_decay = kwargs.pop("weight_decay", 0.0)

            solver = jaxopt.OptaxSolver(
                opt=optax.contrib.muon(
                    learning_rate=lr,
                    beta=beta,
                    ns_steps=ns_steps,
                    nesterov=True,
                    weight_decay=weight_decay,
                ),
                fun=objective,
                **kwargs,
            )
        elif self.optimizer == "gradient_descent":
            solver = jaxopt.GradientDescent(fun=objective, **kwargs)
        else:
            raise ValueError(
                f"Unknown optimizer: {self.optimizer}. "
                f"Choose from: bfgs, lbfgs, lbfgsb, nonlinear_cg, adam, sgd, muon, gradient_descent"
            )

        # Optimize
        self.optpars, self.state = solver.run(params)

        # Post-hoc calibration on validation set if provided
        if val_X is not None and val_y is not None:
            self._calibrate(val_X, val_y)

        return self
364 def _calibrate(self, X, y):
365 """Apply post-hoc calibration using validation set.
367 Implements Eq. 8 from Kellner & Ceriotti (2024):
368 α² = (1/n_val) Σ [Δy(X)² / σ(X)²]
370 This rescales uncertainties so that their magnitude matches
371 actual prediction errors on the validation set.
373 Args:
374 X: Validation features.
375 y: Validation targets.
376 """
377 pY = self.nn.apply(self.optpars, np.asarray(X))
378 py = pY.mean(axis=1)
379 sigma = np.sqrt(pY.var(axis=1) + self.min_sigma**2)
381 errs = np.asarray(y).ravel() - py
383 # Check for ensemble collapse
384 mean_sigma = np.mean(sigma)
385 if mean_sigma < 1e-8:
386 print("\n⚠ WARNING: Ensemble has collapsed!")
387 print(f" Mean uncertainty: {mean_sigma:.2e} (nearly zero)")
388 print(f" Ensemble spread: {sigma.min():.2e} to {sigma.max():.2e}")
389 print("\n Possible causes:")
390 print(f" - Ensemble size too small (current: {self.n_ensemble})")
391 print(" - Training with MSE loss (use 'nll' or 'crps')")
392 print(" - Model overfit (reduce training iterations)")
393 print("\n Skipping calibration (using α = 1.0)")
394 self.calibration_factor = 1.0
395 return
397 # Calibration factor: ratio of empirical to predicted variance
398 alpha_sq = np.mean(errs**2) / np.mean(sigma**2)
399 self.calibration_factor = float(np.sqrt(alpha_sq))
401 # Check for numerical issues
402 if not np.isfinite(self.calibration_factor):
403 print(f"\n⚠ WARNING: Calibration failed (α = {self.calibration_factor})")
404 print(f" Mean error²: {np.mean(errs**2):.6f}")
405 print(f" Mean σ²: {np.mean(sigma**2):.6f}")
406 print(" Skipping calibration (using α = 1.0)")
407 self.calibration_factor = 1.0
408 return
410 print(f"\nCalibration factor α = {self.calibration_factor:.4f}")
411 if self.calibration_factor > 1.5:
412 print(" ⚠ Model is overconfident (α > 1.5)")
413 elif self.calibration_factor < 0.7:
414 print(" ⚠ Model is underconfident (α < 0.7)")
415 else:
416 print(" ✓ Model is well-calibrated")
418 def report(self):
419 """Print optimization diagnostics."""
420 print("Optimization converged:")
421 # Handle different state formats from different optimizers
422 if hasattr(self.state, "iter_num"):
423 print(f" Iterations: {self.state.iter_num}")
424 elif hasattr(self.state, "num_iter"):
425 print(f" Iterations: {self.state.num_iter}")
427 if hasattr(self.state, "value"):
428 print(f" Final loss: {self.state.value:.6f}")
430 print(f" Optimizer: {self.optimizer}")
431 print(f" Ensemble size: {self.n_ensemble}")
432 print(f" Loss type: {self.loss_type}")
433 if hasattr(self, "calibration_factor"):
434 print(f" Calibration: α = {self.calibration_factor:.4f}")
436 def predict(self, X, return_std=False):
437 """Make predictions with uncertainty estimates.
439 Args:
440 X: Input features, shape (n_samples, n_features).
441 return_std: If True, return (predictions, uncertainties).
443 Returns:
444 predictions: Mean ensemble predictions, shape (n_samples,).
445 uncertainties: Standard deviation (if return_std=True), shape (n_samples,).
446 """
447 X = np.atleast_2d(X)
448 P = self.nn.apply(self.optpars, X)
450 mean_pred = P.mean(axis=1)
451 std_pred = np.sqrt(P.var(axis=1) + self.min_sigma**2)
453 # Apply post-hoc calibration if available
454 if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
455 std_pred = std_pred * self.calibration_factor
457 if return_std:
458 return mean_pred, std_pred
459 else:
460 return mean_pred
462 def predict_ensemble(self, X):
463 """Get full ensemble predictions for uncertainty propagation.
465 This method is crucial for propagating uncertainties through
466 non-linear transformations (Kellner & Ceriotti, Eq. 11).
468 Example:
469 # For some function f(y)
470 ensemble_preds = model.predict_ensemble(X) # (n_samples, n_ensemble)
471 z_ensemble = f(ensemble_preds) # Apply f to each member
472 z_mean = z_ensemble.mean(axis=1) # Mean of transformed quantity
473 z_std = z_ensemble.std(axis=1) # Uncertainty of transformed quantity
475 Args:
476 X: Input features, shape (n_samples, n_features).
478 Returns:
479 ensemble_predictions: Full ensemble output, shape (n_samples, n_ensemble).
480 """
481 X = np.atleast_2d(X)
482 return self.nn.apply(self.optpars, X)
484 def __call__(self, X, return_std=False, distribution=False):
485 """Execute the model (alternative interface to predict).
487 Args:
488 X: Input features, shape (n_samples, n_features).
489 return_std: If True, return uncertainties.
490 distribution: If True, return full ensemble; else return mean.
492 Returns:
493 If distribution=False: predictions (and uncertainties if return_std=True).
494 If distribution=True: full ensemble predictions, shape (n_samples, n_ensemble).
495 """
496 if not hasattr(self, "optpars"):
497 raise Exception("You need to fit the model first.")
499 X = np.atleast_2d(X)
500 P = self.nn.apply(self.optpars, X)
502 if distribution:
503 # Return full ensemble
504 if return_std:
505 se = np.sqrt(P.var(axis=1) + self.min_sigma**2)
506 if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
507 se = se * self.calibration_factor
508 return (P, se)
509 else:
510 return P
511 else:
512 # Return mean (and std if requested)
513 mean_pred = P.mean(axis=1)
514 if return_std:
515 std_pred = np.sqrt(P.var(axis=1) + self.min_sigma**2)
516 if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
517 std_pred = std_pred * self.calibration_factor
518 return (mean_pred, std_pred)
519 else:
520 return mean_pred
    def plot(self, X, y, distribution=False, ax=None):
        """Visualize predictions with uncertainty bands.

        Args:
            X: Input features, shape (n_samples, n_features).
                For 1D input, will be used as x-axis.
                NOTE(review): X.ravel() is used directly as the x-axis, so
                this assumes a single feature — confirm before using with
                multi-feature inputs.
            y: True target values.
            distribution: If True, plot individual ensemble members.
            ax: Matplotlib axis (optional). If None, uses current axis.

        Returns:
            matplotlib figure object.
        """
        if ax is None:
            ax = plt.gca()

        # Get predictions with calibrated uncertainties
        mp, se = self.predict(X, return_std=True)

        # For line plots, need to sort by X (scatter data may be unordered)
        X_plot = X.ravel()
        sort_idx = np.argsort(X_plot)
        X_sorted = X_plot[sort_idx]
        mp_sorted = mp[sort_idx]
        se_sorted = se[sort_idx]

        # Plot in correct z-order (back to front):

        # 1. Uncertainty band (background, lowest z-order)
        ax.fill_between(
            X_sorted,
            mp_sorted - 2 * se_sorted,
            mp_sorted + 2 * se_sorted,
            alpha=0.3,
            color="red",
            label="±2σ (95% CI)",
            zorder=1,
        )

        # 2. Individual ensemble members (if requested, middle layer)
        if distribution:
            P = self.nn.apply(self.optpars, X)
            P_sorted = P[sort_idx, :]
            # Plot all members at once (more efficient) with very low alpha
            ax.plot(X_sorted, P_sorted, "k-", alpha=0.05, linewidth=0.5, zorder=2)

        # 3. Mean prediction line (middle-front, should be visible)
        ax.plot(X_sorted, mp_sorted, "r-", label="mean prediction", linewidth=2.5, zorder=3)

        # 4. Data points (front, highest z-order so always visible)
        ax.plot(X_plot, y, "b.", label="data", alpha=0.7, markersize=8, zorder=4)

        ax.set_xlabel("X")
        ax.set_ylabel("y")
        ax.legend()
        ax.set_title(f"DPOSE Predictions (n_ensemble={self.n_ensemble})")
        ax.grid(True, alpha=0.3)

        return plt.gcf()
582 def uncertainty_metrics(self, X, y):
583 """Compute uncertainty quantification metrics.
585 Following Kellner & Ceriotti (2024), computes several diagnostics
586 to assess the quality of uncertainty estimates.
588 Args:
589 X: Input features.
590 y: True target values.
592 Returns:
593 dict with keys:
594 - 'rmse': Root mean squared error
595 - 'mae': Mean absolute error
596 - 'nll': Negative log-likelihood (lower is better)
597 - 'miscalibration_area': Deviation from ideal calibration (lower is better)
598 - 'z_score_mean': Should be ~0 if well-calibrated
599 - 'z_score_std': Should be ~1 if well-calibrated
600 """
601 mp, se = self.predict(X, return_std=True)
602 y = np.asarray(y).ravel()
604 errs = y - mp
605 rmse = np.sqrt(np.mean(errs**2))
606 mae = np.mean(np.abs(errs))
608 # Check for ensemble collapse (sigma too small)
609 mean_se = np.mean(se)
610 if mean_se < 1e-8:
611 print("\n⚠ WARNING: Cannot compute uncertainty metrics - ensemble has collapsed!")
612 print(f" Mean uncertainty: {mean_se:.2e} (nearly zero)")
613 print(" This causes division by zero in metric calculations.")
614 print("\n Returning basic accuracy metrics only (NLL, Z-scores unavailable)")
616 return {
617 "rmse": float(rmse),
618 "mae": float(mae),
619 "nll": float("nan"),
620 "miscalibration_area": float("nan"),
621 "z_score_mean": float("nan"),
622 "z_score_std": float("nan"),
623 }
625 # Check for any numerical issues in uncertainties
626 if not np.all(np.isfinite(se)) or np.any(se <= 0):
627 print("\n⚠ WARNING: Invalid uncertainty values detected!")
628 print(f" Contains NaN: {np.any(np.isnan(se))}")
629 print(f" Contains inf: {np.any(np.isinf(se))}")
630 print(f" Contains zeros or negatives: {np.any(se <= 0)}")
631 print("\n Returning basic accuracy metrics only")
633 return {
634 "rmse": float(rmse),
635 "mae": float(mae),
636 "nll": float("nan"),
637 "miscalibration_area": float("nan"),
638 "z_score_mean": float("nan"),
639 "z_score_std": float("nan"),
640 }
642 # NLL (Eq. 6)
643 nll = 0.5 * np.mean(errs**2 / se**2 + np.log(2 * np.pi * se**2))
645 # Standardized residuals (z-scores)
646 z_scores = errs / se
647 z_mean = np.mean(z_scores)
648 z_std = np.std(z_scores)
650 # Miscalibration area (Kellner Fig. 2)
651 # Measures deviation of empirical CDF from theoretical Gaussian CDF
652 sorted_z = np.sort(z_scores)
653 empirical_cdf = np.arange(1, len(sorted_z) + 1) / len(sorted_z)
654 theoretical_cdf = jax.scipy.stats.norm.cdf(sorted_z)
655 miscalibration_area = np.mean(np.abs(empirical_cdf - theoretical_cdf))
657 metrics = {
658 "rmse": float(rmse),
659 "mae": float(mae),
660 "nll": float(nll),
661 "miscalibration_area": float(miscalibration_area),
662 "z_score_mean": float(z_mean),
663 "z_score_std": float(z_std),
664 }
666 return metrics
    def print_metrics(self, X, y):
        """Print uncertainty metrics in human-readable format.

        Delegates to uncertainty_metrics() and formats the result; NaN NLL
        signals that the ensemble collapsed and uncertainty metrics are
        unavailable.

        Args:
            X: Input features.
            y: True target values.
        """
        metrics = self.uncertainty_metrics(X, y)

        print("\n" + "=" * 50)
        print("UNCERTAINTY QUANTIFICATION METRICS")
        print("=" * 50)
        print("Prediction Accuracy:")
        print(f"  RMSE: {metrics['rmse']:.6f}")
        print(f"  MAE: {metrics['mae']:.6f}")

        # Check if uncertainty metrics are available (NaN NLL marks the
        # collapsed-ensemble fallback from uncertainty_metrics)
        has_uncertainty = not np.isnan(metrics["nll"])

        if has_uncertainty:
            print("\nUncertainty Quality:")
            print(f"  NLL: {metrics['nll']:.6f} (lower is better)")
            print(f"  Miscalibration Area: {metrics['miscalibration_area']:.6f} (lower is better)")
            print("\nCalibration Diagnostics:")
            print(f"  Z-score mean: {metrics['z_score_mean']:.4f} (ideal: 0)")
            print(f"  Z-score std: {metrics['z_score_std']:.4f} (ideal: 1)")

            # Interpret calibration: mean near 0 and std near 1 is ideal
            if abs(metrics["z_score_mean"]) < 0.1 and abs(metrics["z_score_std"] - 1) < 0.2:
                print("  ✓ Well-calibrated uncertainties")
            elif metrics["z_score_std"] < 0.8:
                print("  ⚠ Overconfident (uncertainties too small)")
            elif metrics["z_score_std"] > 1.2:
                print("  ⚠ Underconfident (uncertainties too large)")
            else:
                print("  ⚠ Miscalibrated")
        else:
            print("\nUncertainty Quality:")
            print("  NLL: N/A (ensemble collapsed)")
            print("  Miscalibration Area: N/A")
            print("\nCalibration Diagnostics:")
            print("  Z-score mean: N/A")
            print("  Z-score std: N/A")
            print("\n  ✗ Uncertainty estimates not available due to ensemble collapse")
            print("  ➜ See warnings above for diagnostic information and suggested fixes")

        print("=" * 50 + "\n")