Coverage for src/pycse/sklearn/nnbr.py: 83.23% (155 statements)

1"""Neural network with Bayesian Linear regression. 

2 

3Use a neural network as a nonlinear feature generator, then use Bayesian Linear 

4regression for the last layer so you can also get UQ. 

5 

6Example: 

7 

8 import numpy as np 

9 from sklearn.neural_network import MLPRegressor 

10 from sklearn.linear_model import BayesianRidge 

11 from sklearn.model_selection import train_test_split 

12 

13 # Generate data 

14 X = np.random.randn(200, 5) 

15 y = np.sum(X**2, axis=1) + 0.1 * np.random.randn(200) 

16 X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2) 

17 

18 # Setup neural network 

19 nn = MLPRegressor( 

20 hidden_layer_sizes=(20, 200), 

21 activation='relu', 

22 solver='lbfgs', 

23 max_iter=1000 

24 ) 

25 

26 # Setup Bayesian Ridge 

27 br = BayesianRidge( 

28 tol=1e-6, 

29 fit_intercept=False, 

30 compute_score=True 

31 ) 

32 

33 # Create and train NNBR 

34 nnbr = NeuralNetworkBLR(nn, br) 

35 nnbr.fit(X_train, y_train, val_X=X_val, val_y=y_val) 

36 

37 # Get predictions with uncertainty 

38 y_pred, y_std = nnbr.predict(X_val, return_std=True) 

39 

40 # Visualize (for 1D input) 

41 nnbr.plot(X, y) 

42 

43 # Print diagnostics 

44 nnbr.report() 

45 nnbr.print_metrics(X_val, y_val) 

46 

47Requires: scikit-learn, numpy, matplotlib 

48""" 

49 

from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.neural_network._base import ACTIVATIONS
import numpy as np
import matplotlib.pyplot as plt



class NeuralNetworkBLR(BaseEstimator, RegressorMixin):
    """sklearn-compatible neural network with Bayesian regression in the last layer.

    The idea is to fit a neural network, then replace its last linear layer with
    a Bayesian linear regressor so you can also estimate uncertainty.
    """


    def __init__(self, nn, br):
        """Initialize the Neural Network Bayesian Linear Regressor.

        Args:
            nn: An sklearn.neural_network.MLPRegressor instance
            br: An sklearn.linear_model.BayesianRidge instance
        """
        self.nn = nn
        self.br = br
        self.calibration_factor = 1.0  # For post-hoc calibration


    def _feat(self, X):
        """Return neural network features for X."""
        import warnings

        weights = self.nn.coefs_
        biases = self.nn.intercepts_

        # Suppress numerical warnings during feature extraction
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)

            # Get the output of the last hidden layer
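            # (a manual forward pass through every hidden layer; the final
            # linear output layer is skipped because BayesianRidge replaces it)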

            feat = X @ weights[0] + biases[0]
            ACTIVATIONS[self.nn.activation](feat)  # works in place
            for i in range(1, len(weights) - 1):
                feat = feat @ weights[i] + biases[i]
                ACTIVATIONS[self.nn.activation](feat)

        return feat


    def fit(self, X, y, val_X=None, val_y=None):
        """Fit the regressor to X, y.

        This first fits the NeuralNetwork instance. Then it gets the features
        from the last hidden layer and uses those in the Bayesian linear
        regressor.

        Args:
            X: Training features, shape (n_samples, n_features)
            y: Training targets, shape (n_samples,)
            val_X: Optional validation features for post-hoc calibration
            val_y: Optional validation targets for post-hoc calibration

        Returns:
            self: Fitted model
        """
        import warnings

        # Suppress numerical warnings during training
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)

            # Stage 1: Fit neural network
            self.nn.fit(X, y)

            # Stage 2: Bayesian linear regression on features
            self.br.fit(self._feat(X), y)

            # Stage 3: Post-hoc calibration if validation data provided
            if val_X is not None and val_y is not None:
                self._calibrate(val_X, val_y)

        return self


    def _calibrate(self, X, y):
        """Apply post-hoc calibration using a validation set.

        Rescales uncertainties so that their magnitude matches
        actual prediction errors on the validation set.

        Args:
            X: Validation features
            y: Validation targets
        """
        y_pred, y_std = self.predict(X, return_std=True)
        errs = np.asarray(y).ravel() - y_pred

        # Check for near-zero uncertainties
        mean_std = np.mean(y_std)
        if mean_std < 1e-8:
            print("\n⚠ WARNING: Uncertainties are near zero!")
            print(f"  Mean uncertainty: {mean_std:.2e}")
            print("  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        # Calibration factor: ratio of empirical to predicted variance
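        # If the predicted σ are off by a constant factor α, then
        # E[err²] = α² · E[σ²], so α = sqrt(mean(err²) / mean(σ²)).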

        alpha_sq = np.mean(errs**2) / np.mean(y_std**2)
        self.calibration_factor = float(np.sqrt(alpha_sq))

        # Check for numerical issues
        if not np.isfinite(self.calibration_factor):
            print(f"\n⚠ WARNING: Calibration failed (α = {self.calibration_factor})")
            print("  Skipping calibration (using α = 1.0)")
            self.calibration_factor = 1.0
            return

        print(f"\nCalibration factor α = {self.calibration_factor:.4f}")
        if self.calibration_factor > 1.5:
            print("  ⚠ Model is overconfident (α > 1.5)")
        elif self.calibration_factor < 0.7:
            print("  ⚠ Model is underconfident (α < 0.7)")
        else:
            print("  ✓ Model is well-calibrated")


    def predict(self, X, return_std=False):
        """Predict output values for X.

        Args:
            X: Input features, shape (n_samples, n_features)
            return_std: If True, return (predictions, uncertainties)

        Returns:
            predictions: Mean predictions, shape (n_samples,)
            uncertainties: Standard deviation (if return_std=True), shape (n_samples,)
        """
        import warnings

        # Suppress numerical warnings during prediction
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            result = self.br.predict(self._feat(X), return_std=return_std)

        if return_std:
            y_pred, y_std = result
            # Apply calibration if available
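            # (the factor was chosen in _calibrate so that validation z-scores
            # have roughly unit variance; 1.0 means no rescaling)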

            if hasattr(self, "calibration_factor") and self.calibration_factor != 1.0:
                y_std = y_std * self.calibration_factor
            return y_pred, y_std
        else:
            return result


    def report(self):
        """Print model diagnostics."""
        print("Model Report:")
        print("  Neural Network:")
        print(f"    Architecture: {self.nn.hidden_layer_sizes}")
        print(f"    Activation: {self.nn.activation}")
        print(f"    Solver: {self.nn.solver}")
        print(f"    Iterations: {self.nn.n_iter_}")
        print(
            f"    Final loss: {self.nn.loss_:.6f}"
            if hasattr(self.nn, "loss_")
            else "    Final loss: N/A"
        )
        print("  Bayesian Ridge:")
        print(f"    Alpha (noise precision): {self.br.alpha_:.6f}")
        print(f"    Lambda (weight precision): {self.br.lambda_:.6f}")
        print(f"    Scores available: {len(self.br.scores_) if hasattr(self.br, 'scores_') else 0}")
        if hasattr(self, "calibration_factor"):
            print(f"    Calibration: α = {self.calibration_factor:.4f}")


    def plot(self, X, y, ax=None):
        """Visualize predictions with uncertainty bands.

        Args:
            X: Input features, shape (n_samples, n_features).
                For 1D input, used as the x-axis.
            y: True target values
            ax: Matplotlib axis (optional). If None, uses the current axis

        Returns:
            The matplotlib figure containing the plot
        """
        if ax is None:
            ax = plt.gca()

        # Get predictions with calibrated uncertainties
        y_pred, y_std = self.predict(X, return_std=True)

        # For line plots, we need to sort by X (assumes 1D input)
        X_plot = X.ravel()
        sort_idx = np.argsort(X_plot)
        X_sorted = X_plot[sort_idx]
        y_pred_sorted = y_pred[sort_idx]
        y_std_sorted = y_std[sort_idx]

        # Plot in correct z-order (back to front):

        # 1. Uncertainty band (background)
        ax.fill_between(
            X_sorted,
            y_pred_sorted - 2 * y_std_sorted,
            y_pred_sorted + 2 * y_std_sorted,
            alpha=0.3,
            color="red",
            label="±2σ (95% CI)",
            zorder=1,
        )

        # 2. Mean prediction line
        ax.plot(X_sorted, y_pred_sorted, "r-", label="mean prediction", linewidth=2.5, zorder=3)

        # 3. Data points (front)
        ax.plot(X_plot, y, "b.", label="data", alpha=0.7, markersize=8, zorder=4)

        # Pad the y-axis limits by 20% of the data range
        y_data_min = np.min(y)
        y_data_max = np.max(y)
        y_range = y_data_max - y_data_min
        ax.set_ylim(y_data_min - 0.2 * y_range, y_data_max + 0.2 * y_range)

        ax.set_xlabel("X")
        ax.set_ylabel("y")
        ax.legend()
        ax.set_title(f"NNBR Predictions (NN: {self.nn.hidden_layer_sizes})")
        ax.grid(True, alpha=0.3)

        return ax.get_figure()


    def uncertainty_metrics(self, X, y):
        """Compute uncertainty quantification metrics.

        Args:
            X: Input features
            y: True target values

        Returns:
            dict with keys:
                - 'rmse': Root mean squared error
                - 'mae': Mean absolute error
                - 'nll': Negative log-likelihood (lower is better)
                - 'miscalibration_area': Deviation from ideal calibration (lower is better)
                - 'z_score_mean': Should be ~0 if well-calibrated
                - 'z_score_std': Should be ~1 if well-calibrated
        """
        y_pred, y_std = self.predict(X, return_std=True)
        y = np.asarray(y).ravel()

        errs = y - y_pred
        rmse = np.sqrt(np.mean(errs**2))
        mae = np.mean(np.abs(errs))

        # Check for near-zero uncertainties
        mean_std = np.mean(y_std)
        if mean_std < 1e-8:
            print("\n⚠ WARNING: Cannot compute uncertainty metrics - uncertainties are near zero!")
            print(f"  Mean uncertainty: {mean_std:.2e}")
            return {
                "rmse": float(rmse),
                "mae": float(mae),
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # Check for numerical issues
        if not np.all(np.isfinite(y_std)) or np.any(y_std <= 0):
            print("\n⚠ WARNING: Invalid uncertainty values detected!")
            return {
                "rmse": float(rmse),
                "mae": float(mae),
                "nll": float("nan"),
                "miscalibration_area": float("nan"),
                "z_score_mean": float("nan"),
                "z_score_std": float("nan"),
            }

        # NLL (negative log-likelihood)
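        # Mean per-point Gaussian NLL: -log N(err | 0, σ²)
        #   = 0.5 * (err² / σ² + log(2π σ²))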

        nll = 0.5 * np.mean(errs**2 / y_std**2 + np.log(2 * np.pi * y_std**2))

        # Standardized residuals (z-scores)
        z_scores = errs / y_std
        z_mean = np.mean(z_scores)
        z_std = np.std(z_scores)

        # Miscalibration area
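        # If the uncertainties are calibrated, the z-scores are standard
        # normal, so the empirical CDF of the sorted z-scores should track
        # Φ(z); the mean absolute gap between the two curves is the area.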

        sorted_z = np.sort(z_scores)
        empirical_cdf = np.arange(1, len(sorted_z) + 1) / len(sorted_z)
        # For the theoretical CDF, use scipy if available, else a rough approximation
        try:
            from scipy.stats import norm

            theoretical_cdf = norm.cdf(sorted_z)
        except ImportError:
            # Rough tanh-based approximation of the standard normal CDF
            theoretical_cdf = 0.5 * (1 + np.tanh(sorted_z / np.sqrt(2)))

        miscalibration_area = np.mean(np.abs(empirical_cdf - theoretical_cdf))

        metrics = {
            "rmse": float(rmse),
            "mae": float(mae),
            "nll": float(nll),
            "miscalibration_area": float(miscalibration_area),
            "z_score_mean": float(z_mean),
            "z_score_std": float(z_std),
        }

        return metrics


    def print_metrics(self, X, y):
        """Print uncertainty metrics in human-readable format.

        Args:
            X: Input features
            y: True target values
        """
        metrics = self.uncertainty_metrics(X, y)

        print("\n" + "=" * 50)
        print("UNCERTAINTY QUANTIFICATION METRICS (NNBR)")
        print("=" * 50)
        print("Prediction Accuracy:")
        print(f"  RMSE: {metrics['rmse']:.6f}")
        print(f"  MAE:  {metrics['mae']:.6f}")

        # Check if uncertainty metrics are available
        has_uncertainty = not np.isnan(metrics["nll"])

        if has_uncertainty:
            print("\nUncertainty Quality:")
            print(f"  NLL: {metrics['nll']:.6f} (lower is better)")
            print(f"  Miscalibration Area: {metrics['miscalibration_area']:.6f} (lower is better)")
            print("\nCalibration Diagnostics:")
            print(f"  Z-score mean: {metrics['z_score_mean']:.4f} (ideal: 0)")
            print(f"  Z-score std:  {metrics['z_score_std']:.4f} (ideal: 1)")

            # Interpret calibration
            if abs(metrics["z_score_mean"]) < 0.1 and abs(metrics["z_score_std"] - 1) < 0.2:
                print("  ✓ Well-calibrated uncertainties")
            elif metrics["z_score_std"] < 0.8:
                print("  ⚠ Overconfident (uncertainties too small)")
            elif metrics["z_score_std"] > 1.2:
                print("  ⚠ Underconfident (uncertainties too large)")
            else:
                print("  ⚠ Miscalibrated")
        else:
            print("\nUncertainty Quality:")
            print("  NLL: N/A (uncertainties near zero)")
            print("\n  ✗ Uncertainty estimates not available")

        print("=" * 50 + "\n")
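

# A minimal, self-contained usage sketch mirroring the docstring example on a
# 1D toy problem so nnbr.plot() is meaningful. This block is illustrative (the
# architecture, tolerances, and seed are assumptions, not recommendations).
if __name__ == "__main__":
    from sklearn.neural_network import MLPRegressor
    from sklearn.linear_model import BayesianRidge
    from sklearn.model_selection import train_test_split

    rng = np.random.default_rng(0)
    X = rng.uniform(-2, 2, (200, 1))  # 1D input for plotting
    y = X.ravel() ** 2 + 0.1 * rng.standard_normal(200)
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)

    nn = MLPRegressor(hidden_layer_sizes=(20, 20), activation="relu",
                      solver="lbfgs", max_iter=1000)
    br = BayesianRidge(tol=1e-6, fit_intercept=False, compute_score=True)

    nnbr = NeuralNetworkBLR(nn, br)
    nnbr.fit(X_train, y_train, val_X=X_val, val_y=y_val)

    y_pred, y_std = nnbr.predict(X_val, return_std=True)
    nnbr.report()
    nnbr.print_metrics(X_val, y_val)

    nnbr.plot(X, y)
    plt.show()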