Coverage for src/pycse/pyroxy.py: 0.00%
94 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
1"""pyroxy - a surrogate decorator

TODO: What about train / test splits?

TODO: What about a random fraction of function values instead of the surrogate?

This is proof-of-concept code, and it is not obviously the best approach. A
notable limitation is that pickle and joblib cannot save this. It works OK with
dill so far.

11"""
13import numpy as np
14import matplotlib.pyplot as plt
15from sklearn.exceptions import NotFittedError
16import dill
19class _Surrogate:
20 def __init__(self, func, model, tol=1, max_calls=-1, verbose=False):
21 """Initialize a Surrogate function.
23 Parameters
24 ----------
26 func : Callable
28 Function that takes one argument
30 model : sklearn model
32 The model must be able to return std errors.
34 tol : float optional, default=1
36 Tolerance to use the surrogate. If the predicted error is less than
37 this we use the surrogate, otherwise use the true function and
38 retrain.
40 max_calls : int, default=-1,
42 Maximum number of calls to allow. An exception is raised if you exceed
43 this. -1 means no limit.
45 verbose : Boolean optional, default=False
46 If truthy, output is more verbose.
48 Returns
49 -------
50 return
52 """
53 self.func = func
54 self.model = model
55 self.tol = tol
56 self.max_calls = max_calls
57 self.verbose = verbose
58 self.xtrain = None
59 self.ytrain = None
61 self.ntrain = 0
62 self.surrogate = 0
63 self.func_calls = 0
65 def add(self, X):
66 """Get data for X, add it and retrain.
67 Use this to bypass the logic for using the surrogate.
68 """
69 if (self.max_calls > 0) and (self.func_calls + 1) > self.max_calls:
70 raise Exception("Max func calls {self.max_calls} will be exceeded.")
72 y = self.func(X)
73 self.func_calls += 1
75 # add it to the data. For now we add all the points
76 if self.xtrain is not None:
77 self.xtrain = np.concatenate([self.xtrain, X], axis=0)
78 self.ytrain = np.concatenate([self.ytrain, y])
79 else:
80 self.xtrain = X
81 self.ytrain = y
83 self.model.fit(self.xtrain, self.ytrain)
84 self.ntrain += 1
85 return y
87 def test(self, X):
88 """Run a test on X.
89 Runs true function on X, computes prediction errors.
91 Returns:
92 True if the actual errors are less than the tolerance.
93 """
94 y = self.func(X)
95 yp, ypse = self.model.predict(X, return_std=True)
97 errs = y - yp
99 if self.verbose:
100 print(
101 f"""Testing {X}
102 y = {y}
103 yp = {yp}
105 ypse = {ypse}
106 ypse < tol = {np.abs(ypse) < self.tol}
108 errs = {errs}
109 errs < tol = {np.abs(errs) < self.tol}
110 """
111 )
112 if (np.max(ypse) < self.tol) and (np.max(np.abs(errs)) < self.tol):
113 return True
114 else:
115 return False
117 def __call__(self, X):
118 """Try to use the surrogate to predict X. if the predicted error is
119 larger than self.tol, use the true function and retrain the surrogate.
121 """
122 try:
123 pf, se = self.model.predict(X, return_std=True)
125 # if we think it is accurate enough we return it
126 if np.all(se < self.tol):
127 self.surrogate += 1
128 return pf
129 else:
130 if self.verbose:
131 print(
132 f"For {X} -> {pf} err={se} is greater than {self.tol},",
133 "running true function and returning function values and retraining",
134 )
136 if (self.max_calls > 0) and (self.func_calls + 1) > self.max_calls:
137 raise Exception(f"Max func calls ({self.max_calls}) will be exceeded.")
138 # Get the true value(s)
139 y = self.func(X)
140 self.func_calls += 1
142 # add it to the data. For now we add all the points
143 self.xtrain = np.concatenate([self.xtrain, X], axis=0)
144 self.ytrain = np.concatenate([self.ytrain, y])
146 self.model.fit(self.xtrain, self.ytrain)
147 self.ntrain += 1
148 pf, se = self.model.predict(X, return_std=True)
149 return y
151 except (AttributeError, NotFittedError):
152 if self.verbose:
153 print(f"Running {X} to initialize the model.")
154 y = self.func(X)
156 self.xtrain = X
157 self.ytrain = y
159 self.model.fit(X, y)
160 self.ntrain += 1
161 return y
163 def plot(self):
164 """Generate a parity plot of the surrogate.
165 Shows approximate 95% uncertainty interval in shaded area.
166 """
168 yp, se = self.model.predict(self.xtrain, return_std=True)
170 # sort these so the points are plotted sequentially in order
171 sind = np.argsort(self.ytrain.flatten())
172 y = self.ytrain.flatten()[sind]
173 yp = yp.flatten()[sind]
174 se = se.flatten()[sind]
176 p = plt.plot(y, yp, "b.")
177 plt.fill_between(
178 y,
179 yp + 2 * se,
180 yp - 2 * se,
181 alpha=0.2,
182 )
183 plt.xlabel("Known y-values")
184 plt.ylabel("Predicted y-values")
185 plt.title(f"R$^2$ = {self.model.score(self.xtrain, self.ytrain)}")
186 return p
188 def __str__(self):
189 """A string representation."""
191 yp, ypse = self.model.predict(self.xtrain, return_std=True)
193 errs = self.ytrain - yp
195 """Returns a string representation of the surrogate."""
196 return f"""{len(self.xtrain)} data points obtained.
197 The model was fitted {self.ntrain} times.
198 The surrogate was successful {self.surrogate} times.
200 model score: {self.model.score(self.xtrain, self.ytrain)}
201 Errors:
202 MAE: {np.mean(np.abs(errs))}
203 RMSE: {np.sqrt(np.mean(errs**2))}
204 (tol = {self.tol})
206 """
208 def dump(self, fname="model.pkl"):
209 """Save the current surrogate to fname."""
210 with open(fname, "wb") as f:
211 f.write(dill.dumps(self))
213 return fname
def Surrogate(function=None, *, model=None, tol=1, verbose=False, max_calls=-1):
    """Function wrapper for the _Surrogate class.

    This allows the class to be used as a decorator with arguments::

        @Surrogate(model=model, tol=0.1)
        def f(X): ...

    It can also be applied directly to a function, e.g.
    ``Surrogate(f, model=model)``.

    Parameters
    ----------
    function : Callable, optional
        The function to wrap. When omitted, a decorator is returned.
    model, tol, verbose, max_calls
        Passed through unchanged to ``_Surrogate``.

    Returns
    -------
    A ``_Surrogate`` wrapping ``function`` when it is given, otherwise a
    decorator that builds one.
    """

    def wrapper(func):
        return _Surrogate(func, model=model, tol=tol, verbose=verbose, max_calls=max_calls)

    # Direct application: wrap immediately instead of returning the decorator,
    # which previously ignored the supplied function.
    if function is not None:
        return wrapper(function)

    return wrapper
# This seems clunky, but I want this to have the syntax:
#     Surrogate.load(fname)


def load(fname="model.pkl"):
    """Load a surrogate from fname.

    Parameters
    ----------
    fname : str, default="model.pkl"
        Path of a file holding a dill-serialized object.

    Returns
    -------
    The object deserialized from the file.
    """
    # NOTE: dill, like pickle, can run arbitrary code while deserializing;
    # only load files that come from a trusted source.
    with open(fname, "rb") as stream:
        return dill.load(stream)
# Expose the loader as an attribute of the Surrogate function so that it can
# be called as Surrogate.load(fname).
Surrogate.load = load