Coverage for src/pycse/pyroxy.py: 0.00%

94 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-10-23 16:23 -0400

1"""pyroxy - a surrogate decorator 

2 

3TODO: What about train / test splits? 

4 

5TODO: what about a random fraction of function values instead of surrogate? 

6 

7This is proof of concept code and it is not obviously the best approach. A 

8notable limitation is that pickle and joblib cannot save this. It works ok with 

9dill so far. 

10 

11""" 

12 

13import numpy as np 

14import matplotlib.pyplot as plt 

15from sklearn.exceptions import NotFittedError 

16import dill 

17 

18 

class _Surrogate:
    """Wrap an expensive function with a self-training surrogate model.

    A call is answered by ``model`` when every predicted standard error is
    below ``tol``. Otherwise the true function is evaluated, the new point is
    added to the training data, and the model is refit.
    """

    def __init__(self, func, model, tol=1, max_calls=-1, verbose=False):
        """Initialize a Surrogate function.

        Parameters
        ----------
        func : Callable
            Function that takes one argument (the input points X) and returns
            the corresponding values.

        model : sklearn model
            Regressor providing ``fit``, ``score`` and
            ``predict(X, return_std=True)``. The model must be able to return
            std errors.

        tol : float, optional, default=1
            Tolerance to use the surrogate. If the predicted error is less
            than this we use the surrogate; otherwise we use the true function
            and retrain.

        max_calls : int, optional, default=-1
            Maximum number of calls to the true function to allow. An
            exception is raised if a call would exceed this. Values <= 0
            (e.g. the default -1) mean no limit.

        verbose : bool, optional, default=False
            If truthy, output is more verbose.
        """
        self.func = func
        self.model = model
        self.tol = tol
        self.max_calls = max_calls
        self.verbose = verbose
        self.xtrain = None  # inputs the true function has been evaluated at
        self.ytrain = None  # true function values at self.xtrain

        self.ntrain = 0  # number of times the model has been fit
        self.surrogate = 0  # number of calls answered by the surrogate
        self.func_calls = 0  # number of calls to the true function

    def _call_func(self, X):
        """Evaluate the true function at X, enforcing and counting max_calls.

        Raises
        ------
        Exception
            If this call would exceed ``max_calls`` (when ``max_calls > 0``).
        """
        if self.max_calls > 0 and self.func_calls + 1 > self.max_calls:
            raise Exception(f"Max func calls ({self.max_calls}) will be exceeded.")
        y = self.func(X)
        self.func_calls += 1
        return y

    def _add_data(self, X, y):
        """Append (X, y) to the training data and refit the model on all of it."""
        if self.xtrain is None:
            self.xtrain = X
            self.ytrain = y
        else:
            # For now we keep every point we have evaluated.
            self.xtrain = np.concatenate([self.xtrain, X], axis=0)
            self.ytrain = np.concatenate([self.ytrain, y])
        self.model.fit(self.xtrain, self.ytrain)
        self.ntrain += 1

    def add(self, X):
        """Get data for X, add it and retrain.

        Use this to bypass the logic for using the surrogate.

        Returns
        -------
        The true function values at X.

        Raises
        ------
        Exception
            If evaluating X would exceed ``max_calls``.
        """
        y = self._call_func(X)
        self._add_data(X, y)
        return y

    def test(self, X):
        """Run a test on X.

        Runs the true function on X (this counts toward ``max_calls``) and
        compares it with the surrogate's predictions. X is not added to the
        training data.

        Returns
        -------
        bool
            True if both the largest predicted std error and the largest
            absolute actual error are less than the tolerance.
        """
        y = self._call_func(X)
        yp, ypse = self.model.predict(X, return_std=True)

        # ravel both sides so (n, 1) vs (n,) shapes do not broadcast to (n, n)
        errs = np.ravel(y) - np.ravel(yp)

        if self.verbose:
            print(
                f"Testing {X}\n"
                f"y = {y}\n"
                f"yp = {yp}\n"
                "\n"
                f"ypse = {ypse}\n"
                f"ypse < tol = {np.abs(ypse) < self.tol}\n"
                "\n"
                f"errs = {errs}\n"
                f"errs < tol = {np.abs(errs) < self.tol}\n"
            )
        return bool(np.max(ypse) < self.tol and np.max(np.abs(errs)) < self.tol)

    def __call__(self, X):
        """Try to use the surrogate to predict X.

        If no data has been collected yet, or if any predicted std error is
        not below ``self.tol``, the true function is evaluated instead. The
        result is then added to the training data and the surrogate is
        retrained.

        Returns
        -------
        The surrogate prediction when it is trusted, otherwise the true
        function values at X.

        Raises
        ------
        Exception
            If the true function is needed but ``max_calls`` would be exceeded.
        """
        if self.xtrain is None:
            # With no training data, an unfitted model either raises or returns
            # uninformed predictions; neither should be trusted.
            if self.verbose:
                print(f"Running {X} to initialize the model.")
            return self.add(X)

        # Keep the try narrow: only a model that cannot predict should trigger
        # this path, not errors raised by the true function or by fit.
        try:
            pf, se = self.model.predict(X, return_std=True)
        except (AttributeError, NotFittedError):
            if self.verbose:
                print(f"Running {X} to initialize the model.")
            return self.add(X)

        # if we think it is accurate enough we return it
        if np.all(se < self.tol):
            self.surrogate += 1
            return pf

        if self.verbose:
            print(
                f"For {X} -> {pf} err={se} is greater than {self.tol},",
                "running true function and returning function values and retraining",
            )
        return self.add(X)

    def plot(self):
        """Generate a parity plot of the surrogate on its training data.

        Shows an approximate 95% uncertainty interval (+/- 2 std errors) as a
        shaded area, with the model's score on the training data in the title.

        Returns
        -------
        The list of artists returned by ``plt.plot``.
        """
        yp, se = self.model.predict(self.xtrain, return_std=True)

        # sort by the known values so the band is drawn in order
        y = np.ravel(self.ytrain)
        order = np.argsort(y)
        y = y[order]
        yp = np.ravel(yp)[order]
        se = np.ravel(se)[order]

        p = plt.plot(y, yp, "b.")
        plt.fill_between(y, yp + 2 * se, yp - 2 * se, alpha=0.2)
        plt.xlabel("Known y-values")
        plt.ylabel("Predicted y-values")
        plt.title(f"R$^2$ = {self.model.score(self.xtrain, self.ytrain)}")
        return p

    def __str__(self):
        """Return a summary of the collected data, fit counts and errors."""
        if self.xtrain is None:
            return (
                "0 data points obtained.\n"
                f"The model was fitted {self.ntrain} times.\n"
                f"The surrogate was successful {self.surrogate} times.\n"
            )

        yp, _ = self.model.predict(self.xtrain, return_std=True)
        # ravel both sides so (n, 1) vs (n,) shapes do not broadcast to (n, n)
        errs = np.ravel(self.ytrain) - np.ravel(yp)

        return (
            f"{len(self.xtrain)} data points obtained.\n"
            f"The model was fitted {self.ntrain} times.\n"
            f"The surrogate was successful {self.surrogate} times.\n"
            "\n"
            f"model score: {self.model.score(self.xtrain, self.ytrain)}\n"
            "Errors:\n"
            f"MAE: {np.mean(np.abs(errs))}\n"
            f"RMSE: {np.sqrt(np.mean(errs**2))}\n"
            f"(tol = {self.tol})\n"
            "\n"
        )

    def dump(self, fname="model.pkl"):
        """Save the current surrogate to fname with dill and return fname."""
        with open(fname, "wb") as f:
            dill.dump(self, f)
        return fname

214 

215 

def Surrogate(function=None, *, model=None, tol=1, verbose=False, max_calls=-1):
    """Function wrapper for the _Surrogate class.

    This allows the class to be used as a decorator with arguments::

        @Surrogate(model=my_model, tol=0.1)
        def f(X): ...

    It can also be applied directly: ``Surrogate(f, model=my_model)``.

    Parameters
    ----------
    function : Callable, optional
        The function to wrap. When given, the _Surrogate is returned directly;
        when omitted, a decorator that builds the _Surrogate is returned.
    model, tol, verbose, max_calls
        Passed through unchanged to ``_Surrogate``.
    """

    def wrapper(func):
        return _Surrogate(func, model=model, tol=tol, verbose=verbose, max_calls=max_calls)

    # Support direct application / bare decoration as well as @Surrogate(...).
    if function is not None:
        return wrapper(function)
    return wrapper

227 

228 

# ``load`` is attached to the ``Surrogate`` decorator below, so that saved
# surrogates can be restored with the syntax ``Surrogate.load(fname)``.


def load(fname="model.pkl"):
    """Restore a surrogate that was saved to ``fname`` with dill.

    Parameters
    ----------
    fname : str, optional, default="model.pkl"
        Path of the file to read.

    Returns
    -------
    The deserialized object (normally a surrogate written by ``dump``).

    Warning: dill, like pickle, can execute arbitrary code while loading; only
    load files you trust.
    """
    with open(fname, "rb") as stream:
        return dill.load(stream)


Surrogate.load = load