Coverage for src/pycse/sklearn/kfoldnn.py: 0.00%

84 statements  


1"""A K-fold Neural network model in jax. 

2 

3The idea of the k-fold model is that you train each neuron in the last layer on 

4a different fold of data. Then, at inference time you get a distribution of 

5predictions that you can use for uncertainty quantification. 

6 

7The main hyperparameter that affects the distribution is the fraction of data 

8used. Empirically I find that a fraction of 0.1 works pretty well. Note that the 

9neurons before the last layer all end up seeing all the data, it is only the 

10last layer that sees different parts of the data. If you use a fraction of 1.0, 

11then each neuron converges to the same result. 

12 

13There isn't currently an obvious way to choose a fraction that leads to the 

14"right" UQ distribution. You can try many values and see what works best. 

15 

16Example usage: 

17 

18import jax 

19import numpy as np 

20import matplotlib.pyplot as plt 

21 

22key = jax.random.PRNGKey(19) 

23 

24x = np.linspace(0, 1, 100)[:, None] 

25y = x**(1/3) + (1 + jax.random.normal(key, x.shape) * 0.05) 

26 

27 

28from pycse.sklearn.kfoldnn import KfoldNN 

29model = KfoldNN((1, 15, 25), xtrain=0.1) 

30 

31model.fit(x, y) 

32 

33model.report() 

34print(model.score(x, y)) 

35model.plot(x, y, distribution=True); 
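
You can also get a point estimate plus an uncertainty estimate for each
point, using the predict method defined below:

yp, se = model.predict(x, return_std=True)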

"""

import os

import jax
from jax import jit, value_and_grad
import jax.numpy as np
from jaxopt import LBFGS
import matplotlib.pyplot as plt
from sklearn.base import BaseEstimator, RegressorMixin
from flax import linen as nn
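
# JAX computes in 32-bit floats by default; the configuration lines below
# opt this module in to 64-bit precision.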

os.environ["JAX_ENABLE_X64"] = "True"
jax.config.update("jax_enable_x64", True)


class _NN(nn.Module):
    """A flax neural network.

    layers: a tuple of integers. Each integer is the number of neurons in that
    layer.
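
    For example, layers=(1, 15, 25) builds Dense(1) -> swish -> Dense(15) ->
    swish -> Dense(25), so calling the network on an input of shape (n, 1)
    returns an array of shape (n, 25): one column per output head. A minimal
    sketch of direct use (assuming flax's standard init/apply API):

        net = _NN((1, 15, 25))
        params = net.init(jax.random.PRNGKey(0), X)  # X: (n, 1) input array
        P = net.apply(params, X)                     # P.shape == (n, 25)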

60 """ 

61 

62 layers: tuple 

63 

64 @nn.compact 

65 def __call__(self, x): 

66 for i in self.layers[0:-1]: 

67 x = nn.Dense(i)(x) 

68 x = nn.swish(x) 

69 

70 # Linear last layer 

71 x = nn.Dense(self.layers[-1])(x) 

72 return x 


class KfoldNN(BaseEstimator, RegressorMixin):
    """sklearn-compatible model for a k-fold neural network."""

    def __init__(self, layers, xtrain=0.1, seed=19):
        """Initialize a k-fold NN.

        Args:
            layers: tuple of integers for the neurons in each layer.
            xtrain: fraction of data to use in each fold.
            seed: integer seed for the JAX PRNG key.
        """
        self.layers = layers
        self.key = jax.random.PRNGKey(seed)
        self.nn = _NN(layers)
        self.xtrain = xtrain

    def fit(self, X, y, **kwargs):
        """Fit the k-fold NN.

        Args:
            X: a 2d array of x values.
            y: an array of y values.

        kwargs are passed to the LBFGS solver.
        """
        # This allows retraining.
        if not hasattr(self, "optpars"):
            params = self.nn.init(self.key, X)
        else:
            params = self.optpars

        last_layer = f"Dense_{len(self.layers) - 1}"
        w = params["params"][last_layer]["kernel"].shape
        N = w[-1]  # number of functions in the last layer
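
        # flax names Dense submodules in creation order (Dense_0, Dense_1,
        # ...) and stores kernels with shape (n_in, n_out), so for
        # layers=(1, 15, 25) this reads N == 25: one output head per fold.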

        folds = jax.random.permutation(
            self.key, np.tile(np.arange(0, len(X))[:, None], N), axis=0, independent=True
        ).T
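
        # folds has shape (N, len(X)); each row is an independent random
        # permutation of the sample indices, i.e., a private shuffling of the
        # data for one output head.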

        # Make a smooth, differentiable cutoff.
        fx = np.arange(0, len(X))

        # We use fy to mask out the errors for the part of the data we
        # don't want each head to see.
        _y = len(X) / 2 * (fx - len(X) * self.xtrain)
        fy = 1 - 0.5 * (np.tanh(_y / 2) + 1)
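
        # fy is ~1 for the first xtrain * len(X) positions and ~0 for the rest,
        # with a smooth tanh transition, so each head effectively trains on the
        # leading xtrain fraction of its own shuffled indices.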

        @jit
        def objective(pars):
            agge = 0

            for i, fold in enumerate(folds):
                # predict for a fold
                P = self.nn.apply(pars, np.asarray(X)[fold])
                errs = (P - y[fold])[:, i] * fy  # errors for this fold

                mae = np.mean(np.abs(errs))  # MAE for the fold
                agge += mae
            return agge
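
        # The mean runs over all len(X) samples while masked entries contribute
        # ~0, so each head's MAE is scaled by roughly xtrain. That scale is the
        # same for every head, so it does not change the location of the minimum.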

        if "maxiter" not in kwargs:
            kwargs["maxiter"] = 1500

        if "tol" not in kwargs:
            kwargs["tol"] = 1e-3

        solver = LBFGS(fun=value_and_grad(objective), value_and_grad=True, **kwargs)

        self.optpars, self.state = solver.run(params)

    def report(self):
        """Print the solver state: iteration count and final objective value."""
        print(f"Iterations: {self.state.iter_num} Value: {self.state.value}")

    def predict(self, X, return_std=False):
        """Predict the model for X.

        Args:
            X: a 2d array of points to predict.
            return_std: Boolean; if True, return an error estimate for each point.

        Returns:
            If return_std is False, the predictions; else (predictions, errors).
        """
        X = np.atleast_2d(X)
        P = self.nn.apply(self.optpars, X)

        if return_std:
            return np.mean(P, axis=1), np.std(P, axis=1)
        else:
            return np.mean(P, axis=1)

    def __call__(self, X, return_std=False, distribution=False):
        """Execute the model.

        Args:
            X: a 2d array to make predictions for.
            return_std: Boolean; if True, return errors for each point.
            distribution: Boolean; if True, return the full distribution, else the mean.
        """
        if not hasattr(self, "optpars"):
            raise Exception("You need to fit the model first.")

        # get predictions
        P = self.nn.apply(self.optpars, X)
        se = P.std(axis=1)
        if not distribution:
            P = P.mean(axis=1)

        if return_std:
            return (P, se)
        else:
            return P

    def plot(self, X, y, distribution=False):
        """Plot the data, the mean prediction, and +/- 2 sd bands.

        Args:
            X: 2d array of data.
            y: corresponding y-values.
            distribution: Boolean; if True, also plot each head's predictions.

        Returns:
            The current matplotlib figure.
        """
        P = self.nn.apply(self.optpars, X)
        mp = P.mean(axis=1)
        se = P.std(axis=1)

        plt.plot(X, y, "b.", label="data")
        plt.plot(X, mp, label="mean")
        plt.plot(X, mp + 2 * se, "k--")
        plt.plot(X, mp - 2 * se, "k--", label="+/- 2sd")
        if distribution:
            plt.plot(X, P, alpha=0.2)
        plt.xlabel("X")
        plt.ylabel("y")
        plt.legend()
        return plt.gcf()