Coverage for src/pycse/sklearn/nngmm.py: 0.00%
182 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
1"""Neural network with Gaussian Mixture Model regression.
3Use a neural network as a nonlinear feature generator, then use a GMM
4regression for the last layer to get uncertainty quantification.
6The GMM approach can capture multimodal distributions and complex
7uncertainty patterns, making it suitable for heteroscedastic noise.
9Example:
11 import numpy as np
12 from sklearn.neural_network import MLPRegressor
13 from sklearn.model_selection import train_test_split
14 from pycse.sklearn.nngmm import NeuralNetworkGMM
16 # Generate data with heteroscedastic noise
17 X = np.random.randn(200, 5)
18 y = np.sum(X**2, axis=1) + (0.1 + 0.5*X[:, 0]**2) * np.random.randn(200)
19 X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
21 # Setup neural network
22 nn = MLPRegressor(
23 hidden_layer_sizes=(20, 200),
24 activation='relu',
25 solver='lbfgs',
26 max_iter=1000
27 )
29 # Create and train NNGMM
30 nngmm = NeuralNetworkGMM(nn, n_components=1)
31 nngmm.fit(X_train, y_train, val_X=X_val, val_y=y_val)
33 # Get predictions with uncertainty
34 y_pred, y_std = nngmm.predict(X_val, return_std=True)
36 # Visualize (for 1D input)
37 nngmm.plot(X, y)
39 # Print diagnostics
40 nngmm.report()
41 nngmm.print_metrics(X_val, y_val)
43Requires: scikit-learn, numpy, matplotlib, gmr
44"""
46from sklearn.base import BaseEstimator, RegressorMixin
47from sklearn.neural_network._base import ACTIVATIONS
48import numpy as np
49import matplotlib.pyplot as plt
50from gmr import GMM
class NeuralNetworkGMM(BaseEstimator, RegressorMixin):
    """sklearn-compatible neural network with GMM regression in last layer.

    The idea is you fit a neural network and replace the last linear layer
    with a Gaussian Mixture Model regressor to estimate uncertainty.

    The GMM can capture complex, multimodal uncertainty distributions.
    """

    def __init__(self, nn, n_components=1, n_samples=500):
        """Initialize the Neural Network GMM Regressor.

        Args:
            nn: An sklearn.neural_network.MLPRegressor instance
            n_components: Number of GMM components (default: 1)
            n_samples: Number of samples for uncertainty estimation (default: 500)
        """
        self.nn = nn
        self.n_components = n_components
        self.n_samples = n_samples
        self.calibration_factor = 1.0  # For post-hoc calibration

    def _feat(self, X):
        """Return neural network features for X.

        Forward-propagates X through every layer of the fitted network
        except the final linear output layer, so the result is the
        activation of the last hidden layer.

        Args:
            X: Input features, shape (n_samples, n_features)

        Returns:
            Features from last hidden layer, shape (n_samples, hidden_size)
        """
        import warnings

        weights = self.nn.coefs_
        biases = self.nn.intercepts_

        # Suppress numerical warnings during feature extraction
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)

            # First hidden layer
            feat = X @ weights[0] + biases[0]
            ACTIVATIONS[self.nn.activation](feat)  # works in place
            # Remaining hidden layers; the final (output) layer is
            # deliberately skipped — the GMM replaces it.
            for i in range(1, len(weights) - 1):
                feat = feat @ weights[i] + biases[i]
                ACTIVATIONS[self.nn.activation](feat)

        return feat

    def fit(self, X, y, val_X=None, val_y=None):
        """Fit the regressor to X, y.

        This first fits the NeuralNetwork instance. Then it gets the features
        from the output layer and uses those in the GMM regressor.

        Args:
            X: Training features, shape (n_samples, n_features)
            y: Training targets, shape (n_samples,)
            val_X: Optional validation features for post-hoc calibration
            val_y: Optional validation targets for post-hoc calibration

        Returns:
            self: Fitted model
        """
        import warnings

        # Suppress ALL numerical warnings during training
        # (including from sklearn and scipy internals)
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)

            # Initial fit of neural network
            self.nn.fit(X, y)

            # Fit the joint density p(features, y); regression is then
            # done by conditioning on the features.
            self.gmm = GMM(n_components=self.n_components)
            features = self._feat(X)
            self.gmm.from_samples(np.hstack([features, y[:, None]]))

        # Post-hoc calibration on validation set
        if val_X is not None and val_y is not None:
            self._calibrate(val_X, val_y)

        return self

    def _calibrate(self, X, y):
        """Perform post-hoc calibration of uncertainties.

        Computes a calibration factor that rescales predicted uncertainties
        to better match empirical errors on the validation set.

        Args:
            X: Validation features
            y: Validation targets
        """
        # Get predictions and uncertainties
        y_pred, y_std = self.predict(X, return_std=True)
        # Flatten to avoid broadcasting issues (y may be 1D or 2D)
        errors = np.asarray(y).ravel() - np.asarray(y_pred).ravel()
        y_std = np.asarray(y_std).ravel()

        # Check for collapsed uncertainties
        mean_std = np.mean(y_std)
        if mean_std < 1e-8:
            print("\n⚠ WARNING: GMM has collapsed to deterministic predictions!")
            print(f"  Mean uncertainty: {mean_std:.2e} (nearly zero)")
            print(f"  Uncertainty spread: {y_std.min():.2e} to {y_std.max():.2e}")
            print("\n  Possible causes:")
            print(f"  - GMM components: {self.n_components} (try increasing)")
            print("  - Neural network overfit (reduce training iterations)")
            print("  - Too few samples for uncertainty estimation")
            print("\n  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        # Calibration factor: ratio of empirical to predicted variance
        alpha_sq = np.mean(errors**2) / np.mean(y_std**2)
        self.calibration_factor = float(np.sqrt(alpha_sq))

        # Check for numerical issues
        if not np.isfinite(self.calibration_factor):
            print(f"\n⚠ WARNING: Calibration failed (α = {self.calibration_factor})")
            print(f"  Mean error²: {np.mean(errors**2):.6f}")
            print(f"  Mean σ²: {np.mean(y_std**2):.6f}")
            print("  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        print(f"\nCalibration factor α = {self.calibration_factor:.4f}")
        if 0.9 <= self.calibration_factor <= 1.1:
            print("  ✓ Model is well-calibrated")

    def predict(self, X, return_std=False):
        """Predict output values for X.

        Args:
            X: Input features, shape (n_samples, n_features)
            return_std: If True, also return standard deviation for each prediction

        Returns:
            y_pred: Predicted values, shape (n_samples,)
            y_std: Standard deviations (if return_std=True), shape (n_samples,)
        """
        import warnings

        # Suppress ALL numerical warnings during prediction
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)

            # Get features from neural network
            feat = self._feat(X)

            # GMM prediction indices (input dimensions)
            inds = np.arange(0, feat.shape[1])

            # Predict mean. gmr's GMM.predict returns shape (n_samples, 1);
            # ravel to the documented (n_samples,) shape so callers cannot
            # accidentally broadcast (n, 1) - (n,) into an (n, n) matrix.
            y = self.gmm.predict(inds, feat).ravel()

            if return_std:
                se = []
                for f in feat:
                    # Condition GMM on the features
                    g = self.gmm.condition(np.arange(len(f)), f)
                    # Sample from conditional distribution
                    samples = g.sample(self.n_samples)
                    # Compute standard deviation
                    se.append(np.std(samples))

                # Apply calibration factor
                return y, self.calibration_factor * np.array(se)
            else:
                return y

    def report(self):
        """Print model diagnostics and configuration."""
        print("\n" + "=" * 50)
        print("NEURAL NETWORK GMM MODEL")
        print("=" * 50)
        print("Neural Network:")
        print(f"  Architecture: {self.nn.hidden_layer_sizes}")
        print(f"  Activation: {self.nn.activation}")
        print(f"  Solver: {self.nn.solver}")
        print(f"  Iterations: {self.nn.n_iter_}")
        print("\nGMM Configuration:")
        print(f"  Components: {self.n_components}")
        print(f"  Samples for UQ: {self.n_samples}")
        print("\nCalibration:")
        print(f"  Calibration factor α: {self.calibration_factor:.4f}")
        print("=" * 50 + "\n")

    def plot(self, X, y, ax=None):
        """Plot predictions with uncertainty bands.

        Only works for 1D input (or shows first feature if multidimensional).

        Args:
            X: Input features
            y: True targets
            ax: Optional matplotlib axes object
        """
        if ax is None:
            fig, ax = plt.subplots(figsize=(10, 6))

        # Get predictions
        y_pred, y_std = self.predict(X, return_std=True)

        # Flatten predictions to 1D (defensive; predict already returns 1D)
        y_pred = y_pred.ravel()
        y_std = y_std.ravel()

        # Handle multi-dimensional input (use first feature)
        if X.ndim > 1 and X.shape[1] > 1:
            X_plot = X[:, 0]
            print(f"Note: Plotting first feature only (input has {X.shape[1]} features)")
        else:
            X_plot = X.ravel()

        # Sort by X for smooth line plots
        sort_idx = np.argsort(X_plot)
        X_sorted = X_plot[sort_idx]
        y_pred_sorted = y_pred[sort_idx]
        y_std_sorted = y_std[sort_idx]

        # Plot uncertainty band (95% CI)
        ax.fill_between(
            X_sorted,
            y_pred_sorted - 2 * y_std_sorted,
            y_pred_sorted + 2 * y_std_sorted,
            alpha=0.3,
            color="red",
            label="±2σ (95% CI)",
            zorder=1,
        )

        # Plot mean prediction
        ax.plot(X_sorted, y_pred_sorted, "r-", linewidth=2.5, label="GMM prediction", zorder=3)

        # Plot data points
        ax.scatter(X_plot, y, alpha=0.5, s=30, color="blue", label="Data", zorder=4)

        # Set y-axis limits: ±20% of data range
        y_data_min = np.min(y)
        y_data_max = np.max(y)
        y_range = y_data_max - y_data_min
        ax.set_ylim(y_data_min - 0.2 * y_range, y_data_max + 0.2 * y_range)

        ax.set_xlabel("X" if X.ndim == 1 else "X[0]", fontsize=12)
        ax.set_ylabel("y", fontsize=12)
        ax.set_title(
            "Neural Network GMM: Predictions with Uncertainty", fontsize=13, fontweight="bold"
        )
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        return ax

    def uncertainty_metrics(self, X, y):
        """Compute uncertainty quantification metrics.

        Evaluates how well the model's uncertainty estimates match
        the empirical errors.

        Args:
            X: Input features
            y: True targets

        Returns:
            dict: Dictionary containing metrics:
                - rmse: Root mean squared error
                - mae: Mean absolute error
                - nll: Negative log-likelihood
                - miscalibration_area: Area between calibration curve and ideal
                - z_score_mean: Mean of z-scores (should be ~0)
                - z_score_std: Std of z-scores (should be ~1)
        """
        import math

        y_pred, y_std = self.predict(X, return_std=True)
        # Flatten everything to 1D before subtracting. Previously y_pred could
        # be (n, 1) while y was (n,), so y - y_pred silently broadcast to an
        # (n, n) matrix and corrupted every metric below.
        y = np.asarray(y).ravel()
        y_pred = np.asarray(y_pred).ravel()
        y_std = np.asarray(y_std).ravel()
        errors = y - y_pred

        # Basic accuracy
        rmse = float(np.sqrt(np.mean(errors**2)))
        mae = float(np.mean(np.abs(errors)))

        # Check for collapsed uncertainties
        mean_se = np.mean(y_std)
        if mean_se < 1e-8:
            print("\n⚠ WARNING: Cannot compute uncertainty metrics - GMM has collapsed!")
            print(f"  Mean uncertainty: {mean_se:.2e} (nearly zero)")
            print("  This causes division by zero in metric calculations.")
            print("\n  Returning basic accuracy metrics only (NLL, Z-scores unavailable)")

            return {
                "rmse": rmse,
                "mae": mae,
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # Check for any numerical issues
        if not np.all(np.isfinite(y_std)) or np.any(y_std <= 0):
            print("\n⚠ WARNING: Invalid uncertainty values detected!")
            print(f"  Contains NaN: {np.any(np.isnan(y_std))}")
            print(f"  Contains inf: {np.any(np.isinf(y_std))}")
            print(f"  Contains zeros or negatives: {np.any(y_std <= 0)}")
            print("\n  Returning basic accuracy metrics only")

            return {
                "rmse": rmse,
                "mae": mae,
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # Negative log-likelihood (Gaussian)
        nll = float(0.5 * np.mean((errors / y_std) ** 2 + np.log(2 * np.pi * y_std**2)))

        # Z-scores (standardized residuals)
        z_scores = errors / y_std
        z_mean = float(np.mean(z_scores))
        z_std = float(np.std(z_scores))

        # Miscalibration area: for thresholds k in [0, 3], compare the
        # empirical coverage P(|error| <= k·σ) against the coverage a
        # calibrated Gaussian would give, erf(k / sqrt(2)). (Previously the
        # empirical coverage was compared against a uniform ramp
        # linspace(0, 1, n), which is not the Gaussian expectation, so even a
        # perfectly calibrated model showed a nonzero area.)
        n = len(errors)
        abs_errors = np.abs(errors)
        ks = np.linspace(0, 3, n)
        expected_coverage = np.array([math.erf(k / math.sqrt(2)) for k in ks])
        actual_coverage = np.array([np.mean(abs_errors <= k * y_std) for k in ks])

        # Compute miscalibration area
        miscalibration_area = float(np.mean(np.abs(actual_coverage - expected_coverage)))

        return {
            "rmse": rmse,
            "mae": mae,
            "nll": nll,
            "miscalibration_area": miscalibration_area,
            "z_score_mean": z_mean,
            "z_score_std": z_std,
        }

    def print_metrics(self, X, y):
        """Print comprehensive uncertainty metrics.

        Args:
            X: Input features
            y: True targets
        """
        metrics = self.uncertainty_metrics(X, y)

        print("\n" + "=" * 50)
        print("UNCERTAINTY QUANTIFICATION METRICS")
        print("=" * 50)
        print("Prediction Accuracy:")
        print(f"  RMSE: {metrics['rmse']:.6f}")
        print(f"  MAE: {metrics['mae']:.6f}")

        # Check if uncertainty metrics are available
        has_uncertainty = not np.isnan(metrics["nll"])

        if has_uncertainty:
            print("\nUncertainty Quality:")
            print(f"  NLL: {metrics['nll']:.6f} (lower is better)")
            print(f"  Miscalibration Area: {metrics['miscalibration_area']:.6f} (lower is better)")
            print("\nCalibration Diagnostics:")
            print(f"  Z-score mean: {metrics['z_score_mean']:.4f} (ideal: 0)")
            print(f"  Z-score std: {metrics['z_score_std']:.4f} (ideal: 1)")

            # Interpret calibration
            if abs(metrics["z_score_mean"]) < 0.1 and abs(metrics["z_score_std"] - 1) < 0.2:
                print("  ✓ Well-calibrated uncertainties")
            elif metrics["z_score_std"] < 0.8:
                print("  ⚠ Overconfident (uncertainties too small)")
            elif metrics["z_score_std"] > 1.2:
                print("  ⚠ Underconfident (uncertainties too large)")
            else:
                print("  ⚠ Miscalibrated")
        else:
            print("\nUncertainty Quality:")
            print("  NLL: N/A (GMM collapsed)")
            print("  Miscalibration Area: N/A")
            print("\nCalibration Diagnostics:")
            print("  Z-score mean: N/A")
            print("  Z-score std: N/A")
            print("\n  ✗ Uncertainty estimates not available due to collapsed GMM")
            print("  ➜ Try increasing n_components or reducing neural network training")

        print("=" * 50 + "\n")