Coverage for src/pycse/sklearn/nngmm.py: 0.00%

182 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 16:23 -0400

1"""Neural network with Gaussian Mixture Model regression. 

2 

3Use a neural network as a nonlinear feature generator, then use a GMM 

4regression for the last layer to get uncertainty quantification. 

5 

6The GMM approach can capture multimodal distributions and complex 

7uncertainty patterns, making it suitable for heteroscedastic noise. 

8 

9Example: 

10 

11 import numpy as np 

12 from sklearn.neural_network import MLPRegressor 

13 from sklearn.model_selection import train_test_split 

14 from pycse.sklearn.nngmm import NeuralNetworkGMM 

15 

16 # Generate data with heteroscedastic noise 

17 X = np.random.randn(200, 5) 

18 y = np.sum(X**2, axis=1) + (0.1 + 0.5*X[:, 0]**2) * np.random.randn(200) 

19 X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2) 

20 

21 # Setup neural network 

22 nn = MLPRegressor( 

23 hidden_layer_sizes=(20, 200), 

24 activation='relu', 

25 solver='lbfgs', 

26 max_iter=1000 

27 ) 

28 

29 # Create and train NNGMM 

30 nngmm = NeuralNetworkGMM(nn, n_components=1) 

31 nngmm.fit(X_train, y_train, val_X=X_val, val_y=y_val) 

32 

33 # Get predictions with uncertainty 

34 y_pred, y_std = nngmm.predict(X_val, return_std=True) 

35 

36 # Visualize (for 1D input) 

37 nngmm.plot(X, y) 

38 

39 # Print diagnostics 

40 nngmm.report() 

41 nngmm.print_metrics(X_val, y_val) 

42 

43Requires: scikit-learn, numpy, matplotlib, gmr 

44""" 

45 

46from sklearn.base import BaseEstimator, RegressorMixin 

47from sklearn.neural_network._base import ACTIVATIONS 

48import numpy as np 

49import matplotlib.pyplot as plt 

50from gmr import GMM 

51 

52 

class NeuralNetworkGMM(BaseEstimator, RegressorMixin):
    """sklearn-compatible neural network with GMM regression in last layer.

    The idea is you fit a neural network and replace the last linear layer
    with a Gaussian Mixture Model regressor to estimate uncertainty.

    The GMM can capture complex, multimodal uncertainty distributions.
    """

    def __init__(self, nn, n_components=1, n_samples=500):
        """Initialize the Neural Network GMM Regressor.

        Args:
            nn: An sklearn.neural_network.MLPRegressor instance
            n_components: Number of GMM components (default: 1)
            n_samples: Number of samples for uncertainty estimation (default: 500)
        """
        self.nn = nn
        self.n_components = n_components
        self.n_samples = n_samples
        self.calibration_factor = 1.0  # For post-hoc calibration

    def _feat(self, X):
        """Return neural network features for X.

        Extracts features from the last hidden layer of the neural network.

        Args:
            X: Input features, shape (n_samples, n_features)

        Returns:
            Features from last hidden layer, shape (n_samples, hidden_size)
        """
        import warnings

        weights = self.nn.coefs_
        biases = self.nn.intercepts_

        # Suppress numerical warnings during feature extraction
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)

            # Get the output of last hidden layer. The final entry of
            # coefs_/intercepts_ is the output layer, which we deliberately
            # skip: the GMM replaces it.
            feat = X @ weights[0] + biases[0]
            ACTIVATIONS[self.nn.activation](feat)  # works in place
            for i in range(1, len(weights) - 1):
                feat = feat @ weights[i] + biases[i]
                ACTIVATIONS[self.nn.activation](feat)

        return feat

    def fit(self, X, y, val_X=None, val_y=None):
        """Fit the regressor to X, y.

        This first fits the NeuralNetwork instance. Then it gets the features
        from the output layer and uses those in the GMM regressor.

        Args:
            X: Training features, shape (n_samples, n_features)
            y: Training targets, shape (n_samples,)
            val_X: Optional validation features for post-hoc calibration
            val_y: Optional validation targets for post-hoc calibration

        Returns:
            self: Fitted model
        """
        import warnings

        # Suppress ALL numerical warnings during training
        # (including from sklearn and scipy internals)
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)

            # Initial fit of neural network
            self.nn.fit(X, y)

            # Create GMM from the joint (features, target) samples so that
            # conditioning on features later yields p(y | features).
            self.gmm = GMM(n_components=self.n_components)
            features = self._feat(X)
            self.gmm.from_samples(np.hstack([features, y[:, None]]))

        # Post-hoc calibration on validation set
        if val_X is not None and val_y is not None:
            self._calibrate(val_X, val_y)

        return self

    def _calibrate(self, X, y):
        """Perform post-hoc calibration of uncertainties.

        Computes a calibration factor that rescales predicted uncertainties
        to better match empirical errors on the validation set.

        Args:
            X: Validation features
            y: Validation targets
        """
        # Get predictions and uncertainties
        y_pred, y_std = self.predict(X, return_std=True)
        # Flatten to avoid broadcasting issues (y_pred is 2D, y is 1D)
        errors = np.asarray(y).ravel() - y_pred.ravel()

        # Check for collapsed uncertainties
        mean_std = np.mean(y_std)
        if mean_std < 1e-8:
            print("\n⚠ WARNING: GMM has collapsed to deterministic predictions!")
            print(f" Mean uncertainty: {mean_std:.2e} (nearly zero)")
            print(f" Uncertainty spread: {y_std.min():.2e} to {y_std.max():.2e}")
            print("\n Possible causes:")
            print(f" - GMM components: {self.n_components} (try increasing)")
            print(" - Neural network overfit (reduce training iterations)")
            print(" - Too few samples for uncertainty estimation")
            print("\n Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        # Calibration factor: ratio of empirical to predicted variance
        alpha_sq = np.mean(errors**2) / np.mean(y_std**2)
        self.calibration_factor = float(np.sqrt(alpha_sq))

        # Check for numerical issues
        if not np.isfinite(self.calibration_factor):
            print(f"\n⚠ WARNING: Calibration failed (α = {self.calibration_factor})")
            print(f" Mean error²: {np.mean(errors**2):.6f}")
            print(f" Mean σ²: {np.mean(y_std**2):.6f}")
            print(" Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        print(f"\nCalibration factor α = {self.calibration_factor:.4f}")
        if 0.9 <= self.calibration_factor <= 1.1:
            print(" ✓ Model is well-calibrated")

    def predict(self, X, return_std=False):
        """Predict output values for X.

        Args:
            X: Input features, shape (n_samples, n_features)
            return_std: If True, also return standard deviation for each prediction

        Returns:
            y_pred: Predicted values, shape (n_samples,)
            y_std: Standard deviations (if return_std=True), shape (n_samples,)
        """
        import warnings

        # Suppress ALL numerical warnings during prediction
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)

            # Get features from neural network
            feat = self._feat(X)

            # GMM prediction indices (input dimensions)
            inds = np.arange(0, feat.shape[1])

            # Predict mean
            y = self.gmm.predict(inds, feat)

            if return_std:
                se = []
                for f in feat:
                    # Condition GMM on the features
                    g = self.gmm.condition(np.arange(len(f)), f)
                    # Sample from conditional distribution
                    samples = g.sample(self.n_samples)
                    # Compute standard deviation
                    se.append(np.std(samples))

                # Apply calibration factor
                return y, self.calibration_factor * np.array(se)
            else:
                return y

    def report(self):
        """Print model diagnostics and configuration."""
        print("\n" + "=" * 50)
        print("NEURAL NETWORK GMM MODEL")
        print("=" * 50)
        print("Neural Network:")
        print(f" Architecture: {self.nn.hidden_layer_sizes}")
        print(f" Activation: {self.nn.activation}")
        print(f" Solver: {self.nn.solver}")
        print(f" Iterations: {self.nn.n_iter_}")
        print("\nGMM Configuration:")
        print(f" Components: {self.n_components}")
        print(f" Samples for UQ: {self.n_samples}")
        print("\nCalibration:")
        print(f" Calibration factor α: {self.calibration_factor:.4f}")
        print("=" * 50 + "\n")

    def plot(self, X, y, ax=None):
        """Plot predictions with uncertainty bands.

        Only works for 1D input (or shows first feature if multidimensional).

        Args:
            X: Input features
            y: True targets
            ax: Optional matplotlib axes object

        Returns:
            The matplotlib axes containing the plot.
        """
        if ax is None:
            fig, ax = plt.subplots(figsize=(10, 6))

        # Get predictions
        y_pred, y_std = self.predict(X, return_std=True)

        # Flatten predictions to 1D (GMM returns 2D arrays)
        y_pred = y_pred.ravel()
        y_std = y_std.ravel()

        # Handle multi-dimensional input (use first feature)
        if X.ndim > 1 and X.shape[1] > 1:
            X_plot = X[:, 0]
            print(f"Note: Plotting first feature only (input has {X.shape[1]} features)")
        else:
            X_plot = X.ravel()

        # Sort by X for smooth line plots
        sort_idx = np.argsort(X_plot)
        X_sorted = X_plot[sort_idx]
        y_pred_sorted = y_pred[sort_idx]
        y_std_sorted = y_std[sort_idx]

        # Plot uncertainty band (95% CI)
        ax.fill_between(
            X_sorted,
            y_pred_sorted - 2 * y_std_sorted,
            y_pred_sorted + 2 * y_std_sorted,
            alpha=0.3,
            color="red",
            label="±2σ (95% CI)",
            zorder=1,
        )

        # Plot mean prediction
        ax.plot(X_sorted, y_pred_sorted, "r-", linewidth=2.5, label="GMM prediction", zorder=3)

        # Plot data points
        ax.scatter(X_plot, y, alpha=0.5, s=30, color="blue", label="Data", zorder=4)

        # Set y-axis limits: ±20% of data range
        y_data_min = np.min(y)
        y_data_max = np.max(y)
        y_range = y_data_max - y_data_min
        ax.set_ylim(y_data_min - 0.2 * y_range, y_data_max + 0.2 * y_range)

        ax.set_xlabel("X" if X.ndim == 1 else "X[0]", fontsize=12)
        ax.set_ylabel("y", fontsize=12)
        ax.set_title(
            "Neural Network GMM: Predictions with Uncertainty", fontsize=13, fontweight="bold"
        )
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        return ax

    def uncertainty_metrics(self, X, y):
        """Compute uncertainty quantification metrics.

        Evaluates how well the model's uncertainty estimates match
        the empirical errors.

        Args:
            X: Input features
            y: True targets

        Returns:
            dict: Dictionary containing metrics:
                - rmse: Root mean squared error
                - mae: Mean absolute error
                - nll: Negative log-likelihood
                - miscalibration_area: Area between calibration curve and ideal
                - z_score_mean: Mean of z-scores (should be ~0)
                - z_score_std: Std of z-scores (should be ~1)
        """
        y_pred, y_std = self.predict(X, return_std=True)
        # BUG FIX: the GMM returns y_pred with shape (n, 1); subtracting it
        # from 1D y broadcast to an (n, n) matrix and corrupted every metric.
        # Flatten everything to 1D first (same approach as _calibrate).
        y_pred = np.asarray(y_pred).ravel()
        y_std = np.asarray(y_std).ravel()
        errors = np.asarray(y).ravel() - y_pred

        # Basic accuracy
        rmse = float(np.sqrt(np.mean(errors**2)))
        mae = float(np.mean(np.abs(errors)))

        # Check for collapsed uncertainties
        mean_se = np.mean(y_std)
        if mean_se < 1e-8:
            print("\n⚠ WARNING: Cannot compute uncertainty metrics - GMM has collapsed!")
            print(f" Mean uncertainty: {mean_se:.2e} (nearly zero)")
            print(" This causes division by zero in metric calculations.")
            print("\n Returning basic accuracy metrics only (NLL, Z-scores unavailable)")

            return {
                "rmse": rmse,
                "mae": mae,
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # Check for any numerical issues
        if not np.all(np.isfinite(y_std)) or np.any(y_std <= 0):
            print("\n⚠ WARNING: Invalid uncertainty values detected!")
            print(f" Contains NaN: {np.any(np.isnan(y_std))}")
            print(f" Contains inf: {np.any(np.isinf(y_std))}")
            print(f" Contains zeros or negatives: {np.any(y_std <= 0)}")
            print("\n Returning basic accuracy metrics only")

            return {
                "rmse": rmse,
                "mae": mae,
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # Negative log-likelihood (Gaussian)
        nll = float(0.5 * np.mean((errors / y_std) ** 2 + np.log(2 * np.pi * y_std**2)))

        # Z-scores (standardized residuals)
        z_scores = errors / y_std
        z_mean = float(np.mean(z_scores))
        z_std = float(np.std(z_scores))

        # Miscalibration area (empirical calibration curve)
        # Sort by predicted uncertainty
        sorted_indices = np.argsort(y_std)
        sorted_errors = np.abs(errors[sorted_indices])
        sorted_stds = y_std[sorted_indices]

        # Compute cumulative calibration: empirical coverage of k-sigma
        # intervals for k in [0, 3], compared against the ideal diagonal.
        n = len(sorted_errors)
        expected_coverage = np.linspace(0, 1, n)
        actual_coverage = np.array(
            [np.mean(sorted_errors <= k * sorted_stds) for k in np.linspace(0, 3, n)]
        )

        # Compute miscalibration area
        miscalibration_area = float(np.mean(np.abs(actual_coverage - expected_coverage)))

        return {
            "rmse": rmse,
            "mae": mae,
            "nll": nll,
            "miscalibration_area": miscalibration_area,
            "z_score_mean": z_mean,
            "z_score_std": z_std,
        }

    def print_metrics(self, X, y):
        """Print comprehensive uncertainty metrics.

        Args:
            X: Input features
            y: True targets
        """
        metrics = self.uncertainty_metrics(X, y)

        print("\n" + "=" * 50)
        print("UNCERTAINTY QUANTIFICATION METRICS")
        print("=" * 50)
        print("Prediction Accuracy:")
        print(f" RMSE: {metrics['rmse']:.6f}")
        print(f" MAE: {metrics['mae']:.6f}")

        # Check if uncertainty metrics are available
        has_uncertainty = not np.isnan(metrics["nll"])

        if has_uncertainty:
            print("\nUncertainty Quality:")
            print(f" NLL: {metrics['nll']:.6f} (lower is better)")
            print(f" Miscalibration Area: {metrics['miscalibration_area']:.6f} (lower is better)")
            print("\nCalibration Diagnostics:")
            print(f" Z-score mean: {metrics['z_score_mean']:.4f} (ideal: 0)")
            print(f" Z-score std: {metrics['z_score_std']:.4f} (ideal: 1)")

            # Interpret calibration
            if abs(metrics["z_score_mean"]) < 0.1 and abs(metrics["z_score_std"] - 1) < 0.2:
                print(" ✓ Well-calibrated uncertainties")
            elif metrics["z_score_std"] < 0.8:
                print(" ⚠ Overconfident (uncertainties too small)")
            elif metrics["z_score_std"] > 1.2:
                print(" ⚠ Underconfident (uncertainties too large)")
            else:
                print(" ⚠ Miscalibrated")
        else:
            print("\nUncertainty Quality:")
            print(" NLL: N/A (GMM collapsed)")
            print(" Miscalibration Area: N/A")
            print("\nCalibration Diagnostics:")
            print(" Z-score mean: N/A")
            print(" Z-score std: N/A")
            print("\n ✗ Uncertainty estimates not available due to collapsed GMM")
            print(" ➜ Try increasing n_components or reducing neural network training")

        print("=" * 50 + "\n")