1"""A DPOSE (Direct Propagation of Shallow Ensembles) Neural Network model in JAX. 

2 

3Implementation based on: 

4Kellner, M., & Ceriotti, M. (2024). Uncertainty quantification by direct 

5propagation of shallow ensembles. Machine Learning: Science and Technology, 5(3), 035006. 

6 

7Key features: 

8- Shallow ensemble architecture (only last layer differs across members) 

9- CRPS loss for robust, calibrated uncertainty estimates (default) 

10- Alternative NLL or MSE losses available 

11- Post-hoc calibration on validation set 

12- Ensemble propagation for derived quantities 

13 

14Example usage: 

15 

16 import jax 

17 import numpy as np 

18 import matplotlib.pyplot as plt 

19 

20 # Generate heteroscedastic data 

21 key = jax.random.PRNGKey(19) 

22 x = np.linspace(0, 1, 100)[:, None] 

23 noise_level = 0.01 + 0.1 * x.ravel() # Increasing noise 

24 y = x.ravel()**(1/3) + noise_level * jax.random.normal(key, (100,)) 

25 

26 # Split into train/validation 

27 from sklearn.model_selection import train_test_split 

28 x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42) 

29 

30 # Train with DPOSE (uses CRPS loss and BFGS optimizer by default) 

31 from pycse.sklearn.dpose import DPOSE 

32 model = DPOSE(layers=(1, 15, 32)) 

33 model.fit(x_train, y_train, val_X=x_val, val_y=y_val) 

34 

35 # Or use a different optimizer (e.g., Adam) 

36 model_adam = DPOSE(layers=(1, 15, 32), optimizer='adam') 

37 model_adam.fit(x_train, y_train, val_X=x_val, val_y=y_val, learning_rate=1e-3) 

38 

39 # Or use Muon optimizer (state-of-the-art 2024) 

40 model_muon = DPOSE(layers=(1, 15, 32), optimizer='muon') 

41 model_muon.fit(x_train, y_train, val_X=x_val, val_y=y_val, learning_rate=0.02) 

42 

43 # Get predictions with uncertainty 

44 y_pred, y_std = model.predict(x, return_std=True) 

45 

46 # Visualize 

47 model.plot(x, y, distribution=True) 

48 

49 # For uncertainty propagation on derived quantities 

50 ensemble_preds = model.predict_ensemble(x) # (n_samples, n_ensemble) 

51 # Apply any function f to ensemble members 

52 z_ensemble = f(ensemble_preds) 

53 z_mean = z_ensemble.mean(axis=1) 

54 z_std = z_ensemble.std(axis=1) 

55 

56Requires: flax, jaxopt, jax, scikit-learn 

57""" 

58 

import os

import jax
from jax import jit, value_and_grad
import jax.numpy as np
import jaxopt
import optax
import matplotlib.pyplot as plt
from sklearn.base import BaseEstimator, RegressorMixin
from flax import linen as nn
from flax.linen.initializers import xavier_uniform

os.environ["JAX_ENABLE_X64"] = "True"
jax.config.update("jax_enable_x64", True)


class _NN(nn.Module):
    """A flax neural network.

    layers: a tuple of integers specifying the network architecture.
    - layers[0]: Input dimension (number of features)
    - layers[1:-1]: Hidden layer sizes
    - layers[-1]: Output dimension (ensemble size)

    Example: layers=(5, 20, 32) creates:
    - Input: 5 features
    - Hidden: 20 neurons with activation
    - Output: 32 ensemble members (no activation)
    """

    layers: tuple
    activation: callable

    @nn.compact
    def __call__(self, x):
        # Hidden layers (skip the first element, which is the input dimension)
        for i in self.layers[1:-1]:
            x = nn.Dense(i, kernel_init=xavier_uniform())(x)
            x = self.activation(x)

        # Linear last layer where each row is a set of predictions.
        # The mean on axis=1 is the prediction.
        x = nn.Dense(self.layers[-1])(x)
        return x

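
# Shape sketch (illustrative): with layers=(5, 20, 32) and a batch of 7 inputs,
# _NN maps x of shape (7, 5) through Dense(20) + activation into Dense(32),
# returning shape (7, 32), i.e., 32 ensemble predictions per sample.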

class DPOSE(BaseEstimator, RegressorMixin):
    """DPOSE: Direct Propagation of Shallow Ensembles.

    A shallow ensemble neural network where only the last layer differs across
    ensemble members. Provides calibrated uncertainty estimates through CRPS or
    NLL training.

    The last element of `layers` determines the ensemble size (n_ensemble).
    For example, layers=(5, 10, 32) creates:
    - Input layer: 5 features
    - Hidden layer: 10 neurons
    - Output layer: 32 ensemble members

    Key Features:
    - CRPS loss (default): robust, works out of the box
    - NLL loss: automatically pre-trains with MSE for robustness
    - Post-hoc calibration on a validation set
    - Uncertainty propagation through ensemble members
    """

    def __init__(
        self,
        layers,
        activation=nn.relu,
        seed=19,
        loss_type="crps",
        min_sigma=1e-3,
        optimizer="bfgs",
    ):
        """Initialize a DPOSE model.

        Args:
            layers: Tuple of integers giving the number of neurons in each layer.
                The last value is the ensemble size (recommended: 16-64).
            activation: Activation function for hidden layers (default: ReLU).
            seed: Random seed for weight initialization.
            loss_type: Loss function - 'crps', 'nll', or 'mse' (default: 'crps').
                - 'crps': Continuous ranked probability score (Kellner Eq. 18) - RECOMMENDED.
                  More robust, prevents uncertainty inflation, works out of the box.
                - 'nll': Negative log-likelihood (Kellner Eq. 6).
                  Can fail without pre-training or normalization (see WHY_NLL_FAILS.md).
                - 'mse': Mean squared error (no uncertainty training).
            min_sigma: Minimum standard deviation for numerical stability (default: 1e-3).
                Prevents division by zero when ensemble members are nearly identical.
            optimizer: Optimization algorithm (default: 'bfgs'). Options:
                - 'bfgs': BFGS (quasi-Newton, recommended for smooth objectives)
                - 'lbfgs': Limited-memory BFGS (for larger problems)
                - 'adam': Adam (adaptive learning rate)
                - 'sgd': Stochastic gradient descent
                - 'muon': Muon (orthogonalized momentum, introduced in 2024)
                - 'lbfgsb': L-BFGS-B (with box constraints)
                - 'nonlinear_cg': Nonlinear conjugate gradient
                - 'gradient_descent': Basic gradient descent
        """
        self.layers = layers
        self.n_ensemble = layers[-1]
        self.key = jax.random.PRNGKey(seed)
        # Handle activation=None by defaulting to ReLU
        if activation is None:
            activation = nn.relu
        self.nn = _NN(layers, activation)
        self.loss_type = loss_type
        self.min_sigma = min_sigma
        self.optimizer = optimizer.lower()
        self.calibration_factor = 1.0  # Default: no calibration

    def fit(self, X, y, val_X=None, val_y=None, pretrain_with_mse=None, **kwargs):
        """Fit the DPOSE model with calibrated uncertainty estimation.

        Args:
            X: Training features, shape (n_samples, n_features).
            y: Training targets, shape (n_samples,).
            val_X: Optional validation features for post-hoc calibration.
            val_y: Optional validation targets for post-hoc calibration.
            pretrain_with_mse: If True, pre-train with MSE before NLL training
                (default: auto).
                - For 'nll': defaults to True (robust two-stage training)
                - For 'crps'/'mse': defaults to False (not needed)
                Set to False to disable pre-training for NLL (not recommended).
            **kwargs: Additional arguments passed to the optimizer. Common parameters:
                - maxiter: Maximum iterations (default: 1500)
                - tol: Convergence tolerance (default: 1e-3)
                - pretrain_maxiter: Iterations for MSE pre-training (default: 500)

                Optimizer-specific parameters:
                - BFGS/LBFGS: stepsize, linesearch, max_linesearch_iter
                - Adam: learning_rate (default: 1e-3), b1, b2, eps
                - SGD: learning_rate, momentum
                - Muon: learning_rate (default: 0.02), beta (default: 0.95),
                  ns_steps (default: 5), weight_decay
                - See the jaxopt documentation for full parameter lists.

        Returns:
            self: Fitted model.
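
        Example (illustrative; the hyperparameter values are arbitrary):

            model = DPOSE(layers=(1, 10, 32), loss_type='nll')
            model.fit(x_train, y_train, val_X=x_val, val_y=y_val,
                      maxiter=2000, pretrain_maxiter=300)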

200 """ 

201 # Auto-detect if we should pre-train 

202 if pretrain_with_mse is None: 

203 pretrain_with_mse = self.loss_type == "nll" 

204 

205 # Extract pre-training specific kwargs 

206 pretrain_maxiter = kwargs.pop("pretrain_maxiter", 500) 

207 

208 # Stage 1: MSE pre-training (if using NLL and pretrain enabled) 

209 if pretrain_with_mse and self.loss_type == "nll": 

210 print("\n" + "=" * 70) 

211 print("NLL TRAINING: Two-Stage Approach for Robustness") 

212 print("=" * 70) 

213 print(f"Stage 1: MSE pre-training ({pretrain_maxiter} iterations)") 

214 print(" → Ensures good predictions before uncertainty calibration") 

215 

216 # Temporarily switch to MSE 

217 original_loss = self.loss_type 

218 self.loss_type = "mse" 

219 

220 # Create kwargs for pre-training 

221 pretrain_kwargs = kwargs.copy() 

222 pretrain_kwargs["maxiter"] = pretrain_maxiter 

223 

224 # Pre-train with MSE 

225 self._fit_internal(X, y, val_X=None, val_y=None, **pretrain_kwargs) 

226 

227 # Report pre-training results 

228 y_pred_pretrain = self.predict(X) 

229 mae_pretrain = np.mean(np.abs(y - y_pred_pretrain)) 

230 print(f" ✓ Pre-training complete: MAE = {mae_pretrain:.6f}") 

231 

232 # Restore NLL 

233 self.loss_type = original_loss 

234 

235 print("\nStage 2: NLL fine-tuning (uncertainty calibration)") 

236 print(" → Calibrating uncertainties while maintaining accuracy") 

237 print("=" * 70 + "\n") 

238 

239 # Stage 2: Main training (NLL, CRPS, or MSE) 

240 return self._fit_internal(X, y, val_X, val_y, **kwargs) 

241 

    def _fit_internal(self, X, y, val_X=None, val_y=None, **kwargs):
        """Internal method for the actual fitting (used by fit() for each stage).

        This is separated out to enable two-stage training for NLL.
        """
        # Initialize or reuse parameters
        if not hasattr(self, "optpars"):
            params = self.nn.init(self.key, X)  # Dummy input to init
        else:
            params = self.optpars

        @jit
        def objective(pars):
            """Loss function with per-sample uncertainty from ensemble spread."""
            # Get ensemble predictions: shape (n_samples, n_ensemble)
            pY = self.nn.apply(pars, np.asarray(X))

            # Ensemble statistics
            py = pY.mean(axis=1)  # Predicted mean, (n_samples,)
            # Uncertainty with a numerically stable gradient
            # (avoids NaN when ensemble members are identical)
            sigma = np.sqrt(
                pY.var(axis=1) + self.min_sigma**2
            )  # Predicted uncertainty, (n_samples,)

            # Prediction errors
            errs = np.asarray(y).ravel() - py

            if self.loss_type == "nll":
                # Negative Log-Likelihood (Kellner & Ceriotti, Eq. 6).
                # Penalizes both prediction errors AND miscalibrated uncertainties.
                nll = 0.5 * (errs**2 / sigma**2 + np.log(2 * np.pi * sigma**2))
                return np.mean(nll)

            elif self.loss_type == "crps":
                # Continuous Ranked Probability Score (Kellner & Ceriotti, Eq. 18).
                # More robust than NLL, less sensitive to outliers.
                z = errs / sigma
                phi_z = jax.scipy.stats.norm.pdf(z)
                Phi_z = jax.scipy.stats.norm.cdf(z)
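                # phi_z and Phi_z are the standard-normal pdf and cdf of the
                # standardized errors z; the expression below is the closed-form
                # CRPS of a Gaussian N(py, sigma**2), a standard identity:
                #   CRPS = sigma * (z * (2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi))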

                crps = sigma * (z * (2 * Phi_z - 1) + 2 * phi_z - 1 / np.sqrt(np.pi))
                return np.mean(crps)

            elif self.loss_type == "mse":
                # Simple MSE (no uncertainty training)
                return np.mean(errs**2)

            else:
                raise ValueError(
                    f"Unknown loss_type: {self.loss_type}. Use 'nll', 'crps', or 'mse'."
                )

        # Solver configuration
        if "maxiter" not in kwargs:
            kwargs["maxiter"] = 1500
        if "tol" not in kwargs:
            kwargs["tol"] = 1e-3

        # Select optimizer
        if self.optimizer == "bfgs":
            solver = jaxopt.BFGS(fun=value_and_grad(objective), value_and_grad=True, **kwargs)
        elif self.optimizer == "lbfgs":
            solver = jaxopt.LBFGS(fun=value_and_grad(objective), value_and_grad=True, **kwargs)
        elif self.optimizer == "lbfgsb":
            solver = jaxopt.LBFGSB(fun=value_and_grad(objective), value_and_grad=True, **kwargs)
        elif self.optimizer == "nonlinear_cg":
            solver = jaxopt.NonlinearCG(
                fun=value_and_grad(objective), value_and_grad=True, **kwargs
            )
        elif self.optimizer == "adam":
            # Adam uses OptaxSolver with an optax optimizer
            if "learning_rate" not in kwargs:
                kwargs["learning_rate"] = 1e-3
            solver = jaxopt.OptaxSolver(
                opt=optax.adam(kwargs.pop("learning_rate")), fun=objective, **kwargs
            )
        elif self.optimizer == "sgd":
            # SGD uses OptaxSolver with an optax optimizer
            if "learning_rate" not in kwargs:
                kwargs["learning_rate"] = 1e-2
            lr = kwargs.pop("learning_rate")
            momentum = kwargs.pop("momentum", 0.9)
            solver = jaxopt.OptaxSolver(
                opt=optax.sgd(lr, momentum=momentum), fun=objective, **kwargs
            )
        elif self.optimizer == "muon":
            # Muon uses OptaxSolver with optax.contrib.muon.
            # Muon orthogonalizes momentum updates for 2D parameters.
            if "learning_rate" not in kwargs:
                kwargs["learning_rate"] = 0.02  # Muon typically uses a higher LR than Adam
            lr = kwargs.pop("learning_rate")
            beta = kwargs.pop("beta", 0.95)
            ns_steps = kwargs.pop("ns_steps", 5)
            weight_decay = kwargs.pop("weight_decay", 0.0)
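            # ns_steps is the number of Newton-Schulz iterations Muon uses to
            # (approximately) orthogonalize the momentum updates of 2D weights.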

            solver = jaxopt.OptaxSolver(
                opt=optax.contrib.muon(
                    learning_rate=lr,
                    beta=beta,
                    ns_steps=ns_steps,
                    nesterov=True,
                    weight_decay=weight_decay,
                ),
                fun=objective,
                **kwargs,
            )
        elif self.optimizer == "gradient_descent":
            solver = jaxopt.GradientDescent(fun=objective, **kwargs)
        else:
            raise ValueError(
                f"Unknown optimizer: {self.optimizer}. "
                f"Choose from: bfgs, lbfgs, lbfgsb, nonlinear_cg, adam, sgd, muon, gradient_descent"
            )

        # Optimize
        self.optpars, self.state = solver.run(params)

        # Post-hoc calibration on the validation set, if provided
        if val_X is not None and val_y is not None:
            self._calibrate(val_X, val_y)

        return self

    def _calibrate(self, X, y):
        """Apply post-hoc calibration using a validation set.

        Implements Eq. 8 from Kellner & Ceriotti (2024):

            α² = (1/n_val) Σ [Δy(X)² / σ(X)²]

        This rescales uncertainties so that their magnitude matches the
        actual prediction errors on the validation set.

        Args:
            X: Validation features.
            y: Validation targets.
        """
        pY = self.nn.apply(self.optpars, np.asarray(X))
        py = pY.mean(axis=1)
        sigma = np.sqrt(pY.var(axis=1) + self.min_sigma**2)

        errs = np.asarray(y).ravel() - py

        # Check for ensemble collapse
        mean_sigma = np.mean(sigma)
        if mean_sigma < 1e-8:
            print("\n⚠ WARNING: Ensemble has collapsed!")
            print(f"  Mean uncertainty: {mean_sigma:.2e} (nearly zero)")
            print(f"  Ensemble spread: {sigma.min():.2e} to {sigma.max():.2e}")
            print("\n  Possible causes:")
            print(f"  - Ensemble size too small (current: {self.n_ensemble})")
            print("  - Training with MSE loss (use 'nll' or 'crps')")
            print("  - Model overfit (reduce training iterations)")
            print("\n  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        # Calibration factor: ratio of empirical to predicted variance
        alpha_sq = np.mean(errs**2) / np.mean(sigma**2)
        self.calibration_factor = float(np.sqrt(alpha_sq))
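        # Note: this is the ratio-of-means variant, mean(err**2) / mean(sigma**2);
        # Eq. 8 as written averages the per-sample ratios err**2 / sigma**2.
        # alpha > 1 means the ensemble spread underestimates the actual errors;
        # predicted sigmas are multiplied by alpha in predict().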

        # Check for numerical issues
        if not np.isfinite(self.calibration_factor):
            print(f"\n⚠ WARNING: Calibration failed (α = {self.calibration_factor})")
            print(f"  Mean error²: {np.mean(errs**2):.6f}")
            print(f"  Mean σ²: {np.mean(sigma**2):.6f}")
            print("  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        print(f"\nCalibration factor α = {self.calibration_factor:.4f}")
        if self.calibration_factor > 1.5:
            print("  ⚠ Model is overconfident (α > 1.5)")
        elif self.calibration_factor < 0.7:
            print("  ⚠ Model is underconfident (α < 0.7)")
        else:
            print("  ✓ Model is well-calibrated")

    def report(self):
        """Print optimization diagnostics."""
        print("Optimization converged:")
        # Handle the different state formats of the different optimizers
        if hasattr(self.state, "iter_num"):
            print(f"  Iterations: {self.state.iter_num}")
        elif hasattr(self.state, "num_iter"):
            print(f"  Iterations: {self.state.num_iter}")

        if hasattr(self.state, "value"):
            print(f"  Final loss: {self.state.value:.6f}")

        print(f"  Optimizer: {self.optimizer}")
        print(f"  Ensemble size: {self.n_ensemble}")
        print(f"  Loss type: {self.loss_type}")
        if hasattr(self, "calibration_factor"):
            print(f"  Calibration: α = {self.calibration_factor:.4f}")

    def predict(self, X, return_std=False):
        """Make predictions with uncertainty estimates.

        Args:
            X: Input features, shape (n_samples, n_features).
            return_std: If True, return (predictions, uncertainties).

        Returns:
            predictions: Mean ensemble predictions, shape (n_samples,).
            uncertainties: Standard deviation (if return_std=True), shape (n_samples,).
        """
        X = np.atleast_2d(X)
        P = self.nn.apply(self.optpars, X)

        mean_pred = P.mean(axis=1)
        std_pred = np.sqrt(P.var(axis=1) + self.min_sigma**2)

        # Apply post-hoc calibration if available
        if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
            std_pred = std_pred * self.calibration_factor

        if return_std:
            return mean_pred, std_pred
        else:
            return mean_pred

    def predict_ensemble(self, X):
        """Get full ensemble predictions for uncertainty propagation.

        This method is crucial for propagating uncertainties through
        non-linear transformations (Kellner & Ceriotti, Eq. 11).

        Example:

            # For some function f(y)
            ensemble_preds = model.predict_ensemble(X)  # (n_samples, n_ensemble)
            z_ensemble = f(ensemble_preds)    # Apply f to each member
            z_mean = z_ensemble.mean(axis=1)  # Mean of transformed quantity
            z_std = z_ensemble.std(axis=1)    # Uncertainty of transformed quantity

        Args:
            X: Input features, shape (n_samples, n_features).

        Returns:
            ensemble_predictions: Full ensemble output, shape (n_samples, n_ensemble).
        """
        X = np.atleast_2d(X)
        return self.nn.apply(self.optpars, X)

    def __call__(self, X, return_std=False, distribution=False):
        """Execute the model (an alternative interface to predict).

        Args:
            X: Input features, shape (n_samples, n_features).
            return_std: If True, return uncertainties.
            distribution: If True, return the full ensemble; else return the mean.

        Returns:
            If distribution=False: predictions (and uncertainties if return_std=True).
            If distribution=True: full ensemble predictions, shape (n_samples, n_ensemble).
        """
        if not hasattr(self, "optpars"):
            raise Exception("You need to fit the model first.")

        X = np.atleast_2d(X)
        P = self.nn.apply(self.optpars, X)

        if distribution:
            # Return the full ensemble
            if return_std:
                se = np.sqrt(P.var(axis=1) + self.min_sigma**2)
                if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
                    se = se * self.calibration_factor
                return (P, se)
            else:
                return P
        else:
            # Return the mean (and the std if requested)
            mean_pred = P.mean(axis=1)
            if return_std:
                std_pred = np.sqrt(P.var(axis=1) + self.min_sigma**2)
                if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
                    std_pred = std_pred * self.calibration_factor
                return (mean_pred, std_pred)
            else:
                return mean_pred

    def plot(self, X, y, distribution=False, ax=None):
        """Visualize predictions with uncertainty bands.

        Args:
            X: Input features, shape (n_samples, n_features).
                For 1D input, used as the x-axis.
            y: True target values.
            distribution: If True, plot individual ensemble members.
            ax: Matplotlib axis (optional). If None, uses the current axis.

        Returns:
            The matplotlib figure object.
        """
        if ax is None:
            ax = plt.gca()

        # Get predictions with calibrated uncertainties
        mp, se = self.predict(X, return_std=True)

        # For line plots, we need to sort by X
        X_plot = X.ravel()
        sort_idx = np.argsort(X_plot)
        X_sorted = X_plot[sort_idx]
        mp_sorted = mp[sort_idx]
        se_sorted = se[sort_idx]

        # Plot in correct z-order (back to front):

        # 1. Uncertainty band (background, lowest z-order)
        ax.fill_between(
            X_sorted,
            mp_sorted - 2 * se_sorted,
            mp_sorted + 2 * se_sorted,
            alpha=0.3,
            color="red",
            label="±2σ (95% CI)",
            zorder=1,
        )

        # 2. Individual ensemble members (if requested, middle layer)
        if distribution:
            P = self.nn.apply(self.optpars, X)
            P_sorted = P[sort_idx, :]
            # Plot all members at once (more efficient) with very low alpha
            ax.plot(X_sorted, P_sorted, "k-", alpha=0.05, linewidth=0.5, zorder=2)

        # 3. Mean prediction line (middle-front, should be visible)
        ax.plot(X_sorted, mp_sorted, "r-", label="mean prediction", linewidth=2.5, zorder=3)

        # 4. Data points (front, highest z-order so always visible)
        ax.plot(X_plot, y, "b.", label="data", alpha=0.7, markersize=8, zorder=4)

        ax.set_xlabel("X")
        ax.set_ylabel("y")
        ax.legend()
        ax.set_title(f"DPOSE Predictions (n_ensemble={self.n_ensemble})")
        ax.grid(True, alpha=0.3)

        return plt.gcf()

    def uncertainty_metrics(self, X, y):
        """Compute uncertainty quantification metrics.

        Following Kellner & Ceriotti (2024), computes several diagnostics
        to assess the quality of the uncertainty estimates.

        Args:
            X: Input features.
            y: True target values.

        Returns:
            dict with keys:
            - 'rmse': Root mean squared error
            - 'mae': Mean absolute error
            - 'nll': Negative log-likelihood (lower is better)
            - 'miscalibration_area': Deviation from ideal calibration (lower is better)
            - 'z_score_mean': Should be ~0 if well-calibrated
            - 'z_score_std': Should be ~1 if well-calibrated
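
        Example (illustrative):

            metrics = model.uncertainty_metrics(x_val, y_val)
            print(f"RMSE: {metrics['rmse']:.4f}")
            print(f"Miscalibration area: {metrics['miscalibration_area']:.4f}")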

600 """ 

601 mp, se = self.predict(X, return_std=True) 

602 y = np.asarray(y).ravel() 

603 

604 errs = y - mp 

605 rmse = np.sqrt(np.mean(errs**2)) 

606 mae = np.mean(np.abs(errs)) 

607 

608 # Check for ensemble collapse (sigma too small) 

609 mean_se = np.mean(se) 

610 if mean_se < 1e-8: 

611 print("\n⚠ WARNING: Cannot compute uncertainty metrics - ensemble has collapsed!") 

612 print(f" Mean uncertainty: {mean_se:.2e} (nearly zero)") 

613 print(" This causes division by zero in metric calculations.") 

614 print("\n Returning basic accuracy metrics only (NLL, Z-scores unavailable)") 

615 

616 return { 

617 "rmse": float(rmse), 

618 "mae": float(mae), 

619 "nll": float("nan"), 

620 "miscalibration_area": float("nan"), 

621 "z_score_mean": float("nan"), 

622 "z_score_std": float("nan"), 

623 } 

624 

625 # Check for any numerical issues in uncertainties 

626 if not np.all(np.isfinite(se)) or np.any(se <= 0): 

627 print("\n⚠ WARNING: Invalid uncertainty values detected!") 

628 print(f" Contains NaN: {np.any(np.isnan(se))}") 

629 print(f" Contains inf: {np.any(np.isinf(se))}") 

630 print(f" Contains zeros or negatives: {np.any(se <= 0)}") 

631 print("\n Returning basic accuracy metrics only") 

632 

633 return { 

634 "rmse": float(rmse), 

635 "mae": float(mae), 

636 "nll": float("nan"), 

637 "miscalibration_area": float("nan"), 

638 "z_score_mean": float("nan"), 

639 "z_score_std": float("nan"), 

640 } 

641 

642 # NLL (Eq. 6) 

643 nll = 0.5 * np.mean(errs**2 / se**2 + np.log(2 * np.pi * se**2)) 

644 

645 # Standardized residuals (z-scores) 

646 z_scores = errs / se 

647 z_mean = np.mean(z_scores) 

648 z_std = np.std(z_scores) 

649 

650 # Miscalibration area (Kellner Fig. 2) 

651 # Measures deviation of empirical CDF from theoretical Gaussian CDF 

652 sorted_z = np.sort(z_scores) 

653 empirical_cdf = np.arange(1, len(sorted_z) + 1) / len(sorted_z) 

654 theoretical_cdf = jax.scipy.stats.norm.cdf(sorted_z) 

655 miscalibration_area = np.mean(np.abs(empirical_cdf - theoretical_cdf)) 

656 

657 metrics = { 

658 "rmse": float(rmse), 

659 "mae": float(mae), 

660 "nll": float(nll), 

661 "miscalibration_area": float(miscalibration_area), 

662 "z_score_mean": float(z_mean), 

663 "z_score_std": float(z_std), 

664 } 

665 

666 return metrics 

667 

    def print_metrics(self, X, y):
        """Print uncertainty metrics in a human-readable format.

        Args:
            X: Input features.
            y: True target values.
        """
        metrics = self.uncertainty_metrics(X, y)

        print("\n" + "=" * 50)
        print("UNCERTAINTY QUANTIFICATION METRICS")
        print("=" * 50)
        print("Prediction Accuracy:")
        print(f"  RMSE: {metrics['rmse']:.6f}")
        print(f"  MAE:  {metrics['mae']:.6f}")

        # Check whether the uncertainty metrics are available
        has_uncertainty = not np.isnan(metrics["nll"])

        if has_uncertainty:
            print("\nUncertainty Quality:")
            print(f"  NLL: {metrics['nll']:.6f} (lower is better)")
            print(f"  Miscalibration Area: {metrics['miscalibration_area']:.6f} (lower is better)")
            print("\nCalibration Diagnostics:")
            print(f"  Z-score mean: {metrics['z_score_mean']:.4f} (ideal: 0)")
            print(f"  Z-score std:  {metrics['z_score_std']:.4f} (ideal: 1)")

            # Interpret the calibration
            if abs(metrics["z_score_mean"]) < 0.1 and abs(metrics["z_score_std"] - 1) < 0.2:
                print("  ✓ Well-calibrated uncertainties")
            elif metrics["z_score_std"] < 0.8:
                print("  ⚠ Overconfident (uncertainties too small)")
            elif metrics["z_score_std"] > 1.2:
                print("  ⚠ Underconfident (uncertainties too large)")
            else:
                print("  ⚠ Miscalibrated")
        else:
            print("\nUncertainty Quality:")
            print("  NLL: N/A (ensemble collapsed)")
            print("  Miscalibration Area: N/A")
            print("\nCalibration Diagnostics:")
            print("  Z-score mean: N/A")
            print("  Z-score std: N/A")
            print("\n  ✗ Uncertainty estimates not available due to ensemble collapse")
            print("  ➜ See warnings above for diagnostic information and suggested fixes")

        print("=" * 50 + "\n")