Coverage for src/pycse/sklearn/nnbr.py: 83.23% (155 statements)
coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
1"""Neural network with Bayesian Linear regression.
3Use a neural network as a nonlinear feature generator, then use Bayesian Linear
4regression for the last layer so you can also get UQ.
6Example:
8 import numpy as np
9 from sklearn.neural_network import MLPRegressor
10 from sklearn.linear_model import BayesianRidge
11 from sklearn.model_selection import train_test_split
13 # Generate data
14 X = np.random.randn(200, 5)
15 y = np.sum(X**2, axis=1) + 0.1 * np.random.randn(200)
16 X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
18 # Setup neural network
19 nn = MLPRegressor(
20 hidden_layer_sizes=(20, 200),
21 activation='relu',
22 solver='lbfgs',
23 max_iter=1000
24 )
26 # Setup Bayesian Ridge
27 br = BayesianRidge(
28 tol=1e-6,
29 fit_intercept=False,
30 compute_score=True
31 )
33 # Create and train NNBR
34 nnbr = NeuralNetworkBLR(nn, br)
35 nnbr.fit(X_train, y_train, val_X=X_val, val_y=y_val)
37 # Get predictions with uncertainty
38 y_pred, y_std = nnbr.predict(X_val, return_std=True)
40 # Visualize (for 1D input)
41 nnbr.plot(X, y)
43 # Print diagnostics
44 nnbr.report()
45 nnbr.print_metrics(X_val, y_val)
47Requires: scikit-learn, numpy, matplotlib
48"""

import warnings

import numpy as np
import matplotlib.pyplot as plt

from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.neural_network._base import ACTIVATIONS


class NeuralNetworkBLR(BaseEstimator, RegressorMixin):
    """sklearn-compatible neural network with Bayesian regression in the last layer.

    The idea is to fit a neural network, then replace its last linear layer
    with a Bayesian linear regressor so you can estimate uncertainty.
    """

    def __init__(self, nn, br):
        """Initialize the neural network Bayesian linear regressor.

        Args:
            nn: An sklearn.neural_network.MLPRegressor instance
            br: An sklearn.linear_model.BayesianRidge instance
        """
        self.nn = nn
        self.br = br
        self.calibration_factor = 1.0  # For post-hoc calibration

    def _feat(self, X):
        """Return the neural-network features for X."""
        weights = self.nn.coefs_
        biases = self.nn.intercepts_

        # Suppress numerical warnings during feature extraction
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)

            # Get the output of the last hidden layer
            feat = X @ weights[0] + biases[0]
            ACTIVATIONS[self.nn.activation](feat)  # works in place
            for i in range(1, len(weights) - 1):
                feat = feat @ weights[i] + biases[i]
                ACTIVATIONS[self.nn.activation](feat)

        return feat
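
    # Sanity-check sketch (an illustrative aside, not part of the original
    # module): MLPRegressor uses an identity output activation, so the full
    # network prediction should be recoverable from these features via the
    # final linear layer:
    #
    #     out = model._feat(X) @ model.nn.coefs_[-1] + model.nn.intercepts_[-1]
    #     assert np.allclose(out.ravel(), model.nn.predict(X))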

    def fit(self, X, y, val_X=None, val_y=None):
        """Fit the regressor to X, y.

        This first fits the neural network. Then it takes the features from
        the last hidden layer and fits the Bayesian linear regressor on them.

        Args:
            X: Training features, shape (n_samples, n_features)
            y: Training targets, shape (n_samples,)
            val_X: Optional validation features for post-hoc calibration
            val_y: Optional validation targets for post-hoc calibration

        Returns:
            self: Fitted model
        """
        # Suppress numerical warnings during training
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)

            # Stage 1: fit the neural network
            self.nn.fit(X, y)

            # Stage 2: Bayesian linear regression on the features
            self.br.fit(self._feat(X), y)

            # Stage 3: post-hoc calibration if validation data are provided
            if val_X is not None and val_y is not None:
                self._calibrate(val_X, val_y)

        return self

    def _calibrate(self, X, y):
        """Apply post-hoc calibration using a validation set.

        Rescales the predicted uncertainties so that their magnitude matches
        the actual prediction errors on the validation set.

        Args:
            X: Validation features
            y: Validation targets
        """
        y_pred, y_std = self.predict(X, return_std=True)
        errs = np.asarray(y).ravel() - y_pred

        # Check for near-zero uncertainties
        mean_std = np.mean(y_std)
        if mean_std < 1e-8:
            print("\n⚠ WARNING: Uncertainties are near zero!")
            print(f"  Mean uncertainty: {mean_std:.2e}")
            print("  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        # Calibration factor: ratio of empirical to predicted variance
        alpha_sq = np.mean(errs**2) / np.mean(y_std**2)
        self.calibration_factor = float(np.sqrt(alpha_sq))

        # Check for numerical issues
        if not np.isfinite(self.calibration_factor):
            print(f"\n⚠ WARNING: Calibration failed (α = {self.calibration_factor})")
            print("  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        print(f"\nCalibration factor α = {self.calibration_factor:.4f}")
        if self.calibration_factor > 1.5:
            print("  ⚠ Model is overconfident (α > 1.5)")
        elif self.calibration_factor < 0.7:
            print("  ⚠ Model is underconfident (α < 0.7)")
        else:
            print("  ✓ Model is well-calibrated")
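
    # Worked example (an illustrative aside, not from the original source):
    # if the validation errors have RMS 0.2 while the predicted sigmas have
    # RMS 0.1, then
    #
    #     α = sqrt(mean(errs**2) / mean(y_std**2)) = sqrt(0.04 / 0.01) = 2.0
    #
    # and every predicted standard deviation is doubled at predict() time.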

    def predict(self, X, return_std=False):
        """Predict output values for X.

        Args:
            X: Input features, shape (n_samples, n_features)
            return_std: If True, return (predictions, uncertainties)

        Returns:
            predictions: Mean predictions, shape (n_samples,)
            uncertainties: Standard deviations (if return_std=True), shape (n_samples,)
        """
        # Suppress numerical warnings during prediction
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            result = self.br.predict(self._feat(X), return_std=return_std)

        if return_std:
            y_pred, y_std = result
            # Apply calibration if available
            if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
                y_std = y_std * self.calibration_factor
            return y_pred, y_std
        else:
            return result
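
    # Usage sketch (assumes a fitted model named `model`; the names here are
    # illustrative): about 95% of targets should fall inside the 2σ band if
    # the uncertainties are calibrated.
    #
    #     y_pred, y_std = model.predict(X_val, return_std=True)
    #     coverage = np.mean(np.abs(y_val - y_pred) <= 2 * y_std)
    #     print(f"2σ coverage: {coverage:.1%} (ideal: ~95%)")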

    def report(self):
        """Print model diagnostics."""
        print("Model Report:")
        print("  Neural Network:")
        print(f"    Architecture: {self.nn.hidden_layer_sizes}")
        print(f"    Activation: {self.nn.activation}")
        print(f"    Solver: {self.nn.solver}")
        print(f"    Iterations: {self.nn.n_iter_}")
        print(
            f"    Final loss: {self.nn.loss_:.6f}"
            if hasattr(self.nn, "loss_")
            else "    Final loss: N/A"
        )
        print("  Bayesian Ridge:")
        # In sklearn's BayesianRidge, alpha_ is the estimated precision of the
        # noise and lambda_ is the estimated precision of the weights
        print(f"    Alpha (noise precision): {self.br.alpha_:.6f}")
        print(f"    Lambda (weight precision): {self.br.lambda_:.6f}")
        print(f"    Scores available: {len(self.br.scores_) if hasattr(self.br, 'scores_') else 0}")
        if hasattr(self, "calibration_factor"):
            print(f"    Calibration: α = {self.calibration_factor:.4f}")

    def plot(self, X, y, ax=None):
        """Visualize predictions with uncertainty bands.

        Args:
            X: Input features, shape (n_samples, n_features)
                For 1D input, used as the x-axis
            y: True target values
            ax: Matplotlib axis (optional). If None, uses the current axis

        Returns:
            matplotlib figure object
        """
        if ax is None:
            ax = plt.gca()

        # Get predictions with calibrated uncertainties
        y_pred, y_std = self.predict(X, return_std=True)

        # For line plots we need to sort by X
        X_plot = X.ravel()
        sort_idx = np.argsort(X_plot)
        X_sorted = X_plot[sort_idx]
        y_pred_sorted = y_pred[sort_idx]
        y_std_sorted = y_std[sort_idx]

        # Plot in correct z-order (back to front):

        # 1. Uncertainty band (background)
        ax.fill_between(
            X_sorted,
            y_pred_sorted - 2 * y_std_sorted,
            y_pred_sorted + 2 * y_std_sorted,
            alpha=0.3,
            color="red",
            label="±2σ (95% CI)",
            zorder=1,
        )

        # 2. Mean prediction line
        ax.plot(X_sorted, y_pred_sorted, "r-", label="mean prediction", linewidth=2.5, zorder=3)

        # 3. Data points (front)
        ax.plot(X_plot, y, "b.", label="data", alpha=0.7, markersize=8, zorder=4)

        # Set y-axis limits: pad the data range by ±20%
        y_data_min = np.min(y)
        y_data_max = np.max(y)
        y_range = y_data_max - y_data_min
        ax.set_ylim(y_data_min - 0.2 * y_range, y_data_max + 0.2 * y_range)

        ax.set_xlabel("X")
        ax.set_ylabel("y")
        ax.legend()
        ax.set_title(f"NNBR Predictions (NN: {self.nn.hidden_layer_sizes})")
        ax.grid(True, alpha=0.3)

        return plt.gcf()
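
    # Usage sketch (1-D inputs only, since X is raveled onto the x-axis;
    # the filename is illustrative):
    #
    #     fig = model.plot(X, y)
    #     fig.savefig("nnbr_fit.png")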

    def uncertainty_metrics(self, X, y):
        """Compute uncertainty quantification metrics.

        Args:
            X: Input features
            y: True target values

        Returns:
            dict with keys:
                - 'rmse': Root mean squared error
                - 'mae': Mean absolute error
                - 'nll': Negative log-likelihood (lower is better)
                - 'miscalibration_area': Deviation from ideal calibration (lower is better)
                - 'z_score_mean': Should be ~0 if well-calibrated
                - 'z_score_std': Should be ~1 if well-calibrated
        """
        y_pred, y_std = self.predict(X, return_std=True)
        y = np.asarray(y).ravel()

        errs = y - y_pred
        rmse = np.sqrt(np.mean(errs**2))
        mae = np.mean(np.abs(errs))

        # Check for near-zero uncertainties
        mean_std = np.mean(y_std)
        if mean_std < 1e-8:
            print("\n⚠ WARNING: Cannot compute uncertainty metrics - uncertainties are near zero!")
            print(f"  Mean uncertainty: {mean_std:.2e}")
            return {
                "rmse": float(rmse),
                "mae": float(mae),
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # Check for numerical issues
        if not np.all(np.isfinite(y_std)) or np.any(y_std <= 0):
            print("\n⚠ WARNING: Invalid uncertainty values detected!")
            return {
                "rmse": float(rmse),
                "mae": float(mae),
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # NLL (negative log-likelihood)
        nll = 0.5 * np.mean(errs**2 / y_std**2 + np.log(2 * np.pi * y_std**2))
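        # (Derivation note, added for clarity: assuming a Gaussian predictive
        # distribution, the per-sample NLL of an error e under N(0, σ²) is
        # 0.5 * (e²/σ² + log(2πσ²)); the line above averages this over the
        # dataset.)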

        # Standardized residuals (z-scores)
        z_scores = errs / y_std
        z_mean = np.mean(z_scores)
        z_std = np.std(z_scores)

        # Miscalibration area
        sorted_z = np.sort(z_scores)
        empirical_cdf = np.arange(1, len(sorted_z) + 1) / len(sorted_z)
        # For the theoretical CDF, use scipy if available; else fall back on a
        # rough tanh approximation to the error function
        try:
            from scipy.stats import norm

            theoretical_cdf = norm.cdf(sorted_z)
        except ImportError:
            theoretical_cdf = 0.5 * (1 + np.tanh(sorted_z / np.sqrt(2)))

        miscalibration_area = np.mean(np.abs(empirical_cdf - theoretical_cdf))
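        # (Interpretation note, added for clarity: 0 means the z-scores follow
        # a standard normal exactly; larger values mean the empirical and ideal
        # CDFs diverge. As a rough rule of thumb, values below about 0.05
        # usually indicate good calibration.)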

        metrics = {
            "rmse": float(rmse),
            "mae": float(mae),
            "nll": float(nll),
            "miscalibration_area": float(miscalibration_area),
            "z_score_mean": float(z_mean),
            "z_score_std": float(z_std),
        }

        return metrics

    def print_metrics(self, X, y):
        """Print uncertainty metrics in a human-readable format.

        Args:
            X: Input features
            y: True target values
        """
        metrics = self.uncertainty_metrics(X, y)

        print("\n" + "=" * 50)
        print("UNCERTAINTY QUANTIFICATION METRICS (NNBR)")
        print("=" * 50)
        print("Prediction Accuracy:")
        print(f"  RMSE: {metrics['rmse']:.6f}")
        print(f"  MAE:  {metrics['mae']:.6f}")

        # Check whether uncertainty metrics are available
        has_uncertainty = not np.isnan(metrics["nll"])

        if has_uncertainty:
            print("\nUncertainty Quality:")
            print(f"  NLL: {metrics['nll']:.6f} (lower is better)")
            print(f"  Miscalibration Area: {metrics['miscalibration_area']:.6f} (lower is better)")
            print("\nCalibration Diagnostics:")
            print(f"  Z-score mean: {metrics['z_score_mean']:.4f} (ideal: 0)")
            print(f"  Z-score std:  {metrics['z_score_std']:.4f} (ideal: 1)")

            # Interpret calibration: z = err/σ, so z_std > 1 means the
            # predicted σ are too small (overconfident) and z_std < 1 means
            # they are too large (underconfident)
            if abs(metrics["z_score_mean"]) < 0.1 and abs(metrics["z_score_std"] - 1) < 0.2:
                print("  ✓ Well-calibrated uncertainties")
            elif metrics["z_score_std"] > 1.2:
                print("  ⚠ Overconfident (uncertainties too small)")
            elif metrics["z_score_std"] < 0.8:
                print("  ⚠ Underconfident (uncertainties too large)")
            else:
                print("  ⚠ Miscalibrated")
        else:
            print("\nUncertainty Quality:")
            print("  NLL: N/A (uncertainties near zero)")
            print("\n  ✗ Uncertainty estimates not available")

        print("=" * 50 + "\n")
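

# A minimal smoke-test sketch, not part of the original module: it exercises
# the documented workflow end to end on synthetic 1-D data. The hyperparameters
# below are illustrative assumptions, not recommendations.
if __name__ == "__main__":
    from sklearn.neural_network import MLPRegressor
    from sklearn.linear_model import BayesianRidge
    from sklearn.model_selection import train_test_split

    # Synthetic 1-D regression problem with known noise level
    rng = np.random.default_rng(0)
    X = rng.uniform(-2, 2, size=(200, 1))
    y = np.sin(3 * X).ravel() + 0.1 * rng.normal(size=200)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

    nn = MLPRegressor(hidden_layer_sizes=(50,), activation="tanh", solver="lbfgs", max_iter=2000)
    br = BayesianRidge(tol=1e-6, fit_intercept=False, compute_score=True)

    # Fit with validation data so post-hoc calibration runs, then report
    model = NeuralNetworkBLR(nn, br).fit(X_train, y_train, val_X=X_val, val_y=y_val)
    model.report()
    model.print_metrics(X_val, y_val)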