1"""A K-fold Neural network model in jax.
3The idea of the k-fold model is that you train each neuron in the last layer on
4a different fold of data. Then, at inference time you get a distribution of
5predictions that you can use for uncertainty quantification.
7The main hyperparameter that affects the distribution is the fraction of data
8used. Empirically I find that a fraction of 0.1 works pretty well. Note that the
9neurons before the last layer all end up seeing all the data, it is only the
10last layer that sees different parts of the data. If you use a fraction of 1.0,
11then each neuron converges to the same result.
13There isn't currently an obvious way to choose a fraction that leads to the
14"right" UQ distribution. You can try many values and see what works best.

Example usage:

import jax
import numpy as np
import matplotlib.pyplot as plt

key = jax.random.PRNGKey(19)

x = np.linspace(0, 1, 100)[:, None]
y = x ** (1 / 3) * (1 + jax.random.normal(key, x.shape) * 0.05)

from pycse.sklearn.kfoldnn import KfoldNN
model = KfoldNN((1, 15, 25), xtrain=0.1)

model.fit(x, y)

model.report()
print(model.score(x, y))
model.plot(x, y, distribution=True);
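
# Mean prediction and a per-point uncertainty estimate:
yp, se = model.predict(x, return_std=True)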
37"""

import os
import jax

from jax import jit
import jax.numpy as np
from jax import value_and_grad
from jaxopt import LBFGS
import matplotlib.pyplot as plt
from sklearn.base import BaseEstimator, RegressorMixin
from flax import linen as nn
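
# Use 64-bit floats for the optimization; JAX defaults to 32-bit.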
os.environ["JAX_ENABLE_X64"] = "True"
jax.config.update("jax_enable_x64", True)


class _NN(nn.Module):
    """A flax neural network.

    layers: a tuple of integers. Each integer is the number of neurons in that
    layer.
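
    For example, layers=(1, 8, 25) applies Dense(1) and Dense(8) with swish
    activations, then a linear Dense(25) output layer.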
60 """
62 layers: tuple
64 @nn.compact
65 def __call__(self, x):
66 for i in self.layers[0:-1]:
67 x = nn.Dense(i)(x)
68 x = nn.swish(x)
70 # Linear last layer
71 x = nn.Dense(self.layers[-1])(x)
72 return x


class KfoldNN(BaseEstimator, RegressorMixin):
    """sklearn-compatible model for a k-fold neural network."""

    def __init__(self, layers, xtrain=0.1, seed=19):
        """Initialize a k-fold NN.

        Args:
            layers: tuple of integers for neurons in each layer.
            xtrain: fraction of data to use in each fold.
            seed: integer seed for the random number generator.
        """
        self.layers = layers
        self.key = jax.random.PRNGKey(seed)
        self.nn = _NN(layers)
        self.xtrain = xtrain

    def fit(self, X, y, **kwargs):
        """Fit the k-fold NN.

        Args:
            X: a 2d array of x values.
            y: an array of y values.

        Remaining kwargs are passed to the LBFGS solver.
        """
        # This allows retraining.
        if not hasattr(self, "optpars"):
            params = self.nn.init(self.key, X)
        else:
            params = self.optpars

        last_layer = f"Dense_{len(self.layers) - 1}"
        w = params["params"][last_layer]["kernel"].shape
        N = w[-1]  # number of functions in the last layer
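
        # Each column of the tiled index array is shuffled independently, so
        # row i of `folds` is an independent permutation of the data indices
        # for output neuron i.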
        folds = jax.random.permutation(
            self.key, np.tile(np.arange(0, len(X))[:, None], N), axis=0, independent=True
        ).T

        # Make a smooth, differentiable cutoff.
        fx = np.arange(0, len(X))

        # We use fy to mask out the errors for the part of each fold we don't want.
        _y = len(X) / 2 * (fx - len(X) * self.xtrain)
        fy = 1 - 0.5 * (np.tanh(_y / 2) + 1)
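
        # fy is ~1 for the first len(X) * self.xtrain entries of each permuted
        # fold and ~0 afterwards, so each output neuron effectively trains on a
        # different random subset while the objective stays differentiable.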

        @jit
        def objective(pars):
            agge = 0

            for i, fold in enumerate(folds):
                # predict for a fold
                P = self.nn.apply(pars, np.asarray(X)[fold])
                errs = (P - y[fold])[:, i] * fy  # errors for this fold

                mae = np.mean(np.abs(errs))  # MAE for the fold
                agge += mae
            return agge
133 if "maxiter" not in kwargs:
134 kwargs["maxiter"] = 1500
136 if "tol" not in kwargs:
137 kwargs["tol"] = 1e-3
        solver = LBFGS(fun=value_and_grad(objective), value_and_grad=True, **kwargs)

        self.optpars, self.state = solver.run(params)

    def report(self):
        """Print the solver state variables."""
        print(f"Iterations: {self.state.iter_num} Value: {self.state.value}")

    def predict(self, X, return_std=False):
        """Make predictions for X.

        Args:
            X: a 2d array of points to predict.
            return_std: Boolean; if True, return an error estimate for each point.

        Returns:
            If return_std is False, the predictions; else (predictions, errors).
        """
        X = np.atleast_2d(X)
        P = self.nn.apply(self.optpars, X)
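
        # Each column of P is the prediction from one last-layer neuron.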
        if return_std:
            return np.mean(P, axis=1), np.std(P, axis=1)
        else:
            return np.mean(P, axis=1)

    def __call__(self, X, return_std=False, distribution=False):
        """Execute the model.

        Args:
            X: a 2d array to make predictions for.
            return_std: Boolean; if True, return errors for each point.
            distribution: Boolean; if True, return the distribution, else the mean.
        """
        if not hasattr(self, "optpars"):
            raise Exception("You need to fit the model first.")

        # get predictions
        P = self.nn.apply(self.optpars, X)
        se = P.std(axis=1)
        if not distribution:
            P = P.mean(axis=1)

        if return_std:
            return (P, se)
        else:
            return P

    def plot(self, X, y, distribution=False):
        """Plot the data, mean prediction, and +/- 2 sd band; return the figure.

        Args:
            X: 2d array of data.
            y: corresponding y-values.
            distribution: Boolean; if True, also plot the distribution of predictions.
        """
        P = self.nn.apply(self.optpars, X)
        mp = P.mean(axis=1)
        se = P.std(axis=1)

        plt.plot(X, y, "b.", label="data")
        plt.plot(X, mp, label="mean")
        plt.plot(X, mp + 2 * se, "k--")
        plt.plot(X, mp - 2 * se, "k--", label="+/- 2sd")
        if distribution:
            plt.plot(X, P, alpha=0.2)
        plt.xlabel("X")
        plt.ylabel("y")
        plt.legend()
        return plt.gcf()