Coverage for src/pycse/mcp.py: 0.00%
310 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
« prev ^ index » next coverage.py v7.11.0, created at 2025-10-23 16:23 -0400
1"""A proof of concept pycse MCP server.
31. provide help with pycse functions.
42. a set of DOE functions.
53. tools for python docs
7"""
9import platform
10import sys
11import shutil
12import os
13import json
14import logging
15import re
16import pydoc
17from io import StringIO
18from mcp.server.fastmcp import FastMCP, Image
19from typing import Tuple, List, Union, Dict, Any, Optional, Pattern
20from pydantic import BaseModel, Field
21import pandas as pd
22import pycse
23import pkgutil
24import importlib
25import inspect
26import io
28from pycse.sklearn.lhc import LatinSquare
29from pycse.sklearn.surface_response import SurfaceResponse
30from pycse.sklearn.dpose import DPOSE
32import matplotlib
# Use the non-interactive Agg backend: this server runs headless and only
# renders figures into in-memory PNG buffers, never to a display.
matplotlib.use("Agg")

# Initialize FastMCP server; every @mcp.tool() function below registers on it.
mcp = FastMCP("pycse")
class Factor(BaseModel):
    """Represents a single experimental factor with its levels.

    A factor is one controllable variable in a designed experiment,
    e.g. 'Temperature' with levels [300, 350, 400].
    """

    # The levels may mix ints and floats; order is preserved as given.
    name: str = Field(..., description="Name of the factor (e.g., 'Red', 'Temperature')")
    levels: List[Union[int, float]] = Field(..., description="List of factor levels")
class LatinSquareSpec(BaseModel):
    """Complete specification for a Latin square design.

    It is a list of the factors and their levels. This is the input schema
    for the design_lhc tool.
    """

    factors: List[Factor] = Field(..., description="List of experimental factors")
# this is a clunky way to save state between calls. It is not persistent, and
# probably can be broken without trying too hard. e.g. using it multiple times,
# or mixing lhc and sr. Another day I should look into using a class.
#
# Keys used: "ls" and "design" (Latin square tools), "sr" and "sr_design"
# (surface response tools).
STATE = {}
@mcp.tool()
def design_lhc(inputs: LatinSquareSpec) -> List[Dict[str, Any]]:
    """Design a LatinSquare design of experiments from the inputs.

    Parameters
    ----------
    inputs : LatinSquareSpec
        The experimental factors, each with a name and a list of levels.
        For example: a Latin square where Red is [0, 0.5, 1], Green is
        [0.0, 0.5, 1], Blue is [0, 0.5, 1], and we measure 515nm output.

    Returns
    -------
    list of dict
        The experiments you should do, one record per experiment (the
        design DataFrame serialized with orient="records").

    Side effect: stores the LatinSquare model and the design in STATE so
    analyze_lhc can use them later.
    """
    # Map each factor name to its levels for the LatinSquare constructor.
    levels = {factor.name: factor.levels for factor in inputs.factors}

    ls = LatinSquare(levels)
    STATE["ls"] = ls

    design = ls.design()
    STATE["design"] = design

    return design.to_dict(orient="records")
class LatinSquareResult(BaseModel):
    """Specification for a single measured result.

    A result is identified by an experiment number, and a corresponding float
    result. The experiment number refers to the design returned by design_lhc.
    """

    experiment: int = Field(..., description="Experiment number")
    result: float = Field(..., description="Result as a float")
class LatinSquareResults(BaseModel):
    """List of the results.

    This is a list of (experiment number, result) pairs — the input schema
    for the analyze_lhc tool.
    """

    results: List[LatinSquareResult] = Field(..., description="List of results")
@mcp.tool()
def analyze_lhc(lsr: LatinSquareResults) -> List[Dict[str, Any]]:
    """Analyze the LatinSquare results.

    The results have to be provided in a way that looks like a list of
    (experiment #, result) can be parsed by the LLM.

    Returns an analysis of variance (ANOVA).
    """
    # Tabulate the reported measurements, one row per experiment.
    measurements = pd.DataFrame(
        {
            "Experiment": [entry.experiment for entry in lsr.results],
            "Result": [entry.result for entry in lsr.results],
        }
    )

    # Join measurements onto the stored design by experiment number.
    merged = STATE["design"].merge(measurements, left_index=True, right_on="Experiment")

    model = STATE["ls"]
    model.fit(merged[model.labels], merged["Result"])

    return model.anova().to_dict(orient="records")
# * Surface Response tools


class SurfaceResponseInputs(BaseModel):
    """Class to represent the list of inputs for the surface response model.

    These are the names of the independent (controlled) variables.
    """

    inputs: List[str] = Field(..., description="List of input names")
class SurfaceResponseOutputs(BaseModel):
    """Class to represent the names of the output variables in the surface response model.

    These are the names of the measured (dependent) variables.
    """

    outputs: List[str] = Field(..., description="List of output names")
class SurfaceResponseBound(BaseModel):
    """Class to represent the bounds of a variable as a (min, max) pair."""

    minmax: Tuple[float, float] = Field(..., description="Bounds (min, max) for one variable")
class SurfaceResponseBounds(BaseModel):
    """Class to represent all the bounds of the all the variables.

    One SurfaceResponseBound per input, in the same order as the inputs.
    """

    bounds: List[SurfaceResponseBound] = Field(..., description="List of bounds")
@mcp.tool()
def design_sr(
    inputs: SurfaceResponseInputs,
    outputs: SurfaceResponseOutputs,
    bounds: SurfaceResponseBounds,
) -> List[Dict[str, Any]]:
    """Design a surface response design of experiments.

    Example:
    design a pycse surface response experiment where red changes from 0.0 to 1.0,
    blue from 0.0 to 0.5 and g changes from 0.0 to 1.0 and we measure the 515nm
    channel.
    """
    # SurfaceResponse expects plain [min, max] lists, one per input.
    bound_pairs = [list(bound.minmax) for bound in bounds.bounds]

    sr = SurfaceResponse(inputs=inputs.inputs, outputs=outputs.outputs, bounds=bound_pairs)
    STATE["sr"] = sr

    # shuffle=False keeps the design order stable so experiment numbers are
    # reproducible between calls.
    design = sr.design(shuffle=False)
    STATE["sr_design"] = design

    return design.to_dict(orient="records")
class SurfaceResponseResult(BaseModel):
    """Specification for a result.

    A result is identified by an experiment number, and a corresponding float
    result.

    Note this works for only one result column.
    """

    experiment: int = Field(..., description="Experiment number")
    result: float = Field(..., description="Result as a float")
class SurfaceResponseResults(BaseModel):
    """List of the results.

    This is a list of (experiment number, result) pairs — the input schema
    for the analyze_sr tool.
    """

    results: List[SurfaceResponseResult] = Field(..., description="List of results")
@mcp.tool()
def analyze_sr(data: SurfaceResponseResults) -> str:
    """Analyze the surface response results.

    Returns a table of ANOVA results.
    """
    # Single output column: wrap each measurement in its own row.
    measured = [[entry.result] for entry in data.results]

    sr = STATE["sr"]
    sr.set_output(measured)
    sr.fit()
    return sr.summary()
@mcp.tool()
def sr_parity() -> Image:
    """Return a parity plot as an image.

    You must run the analyze_sr tool before this one.

    Returns
    -------
    Image
        The parity plot rendered as a PNG.
    """
    import matplotlib.pyplot as plt

    fig = STATE["sr"].parity()

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    # Close the figure so repeated calls do not accumulate open figures;
    # matplotlib keeps references to open figures and leaks memory otherwise.
    plt.close(fig)

    buf.seek(0)
    return Image(data=buf.getvalue(), format="png")
@mcp.tool()
def random_image(n: int = 10) -> Image:
    """Return a random image with N points in it.

    Parameters
    ----------
    n : int
        Number of random points to plot (default 10).

    Returns
    -------
    Image
        A line plot of n random values, rendered as a PNG.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # Create a dedicated figure instead of drawing into the implicit current
    # figure (plt.gcf()): otherwise repeated calls pile lines onto the same
    # figure, and the figure is never released.
    fig, ax = plt.subplots()
    ax.plot(np.random.rand(n))

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)

    buf.seek(0)
    return Image(data=buf.getvalue(), format="png")
@mcp.tool()
def pycse_help() -> str:
    """Get help about pycse functions.

    Walks the pycse package, imports each submodule, and collects the
    docstring of every function defined inside pycse.

    Returns
    -------
    str
        A listing of "function : docstring" entries.
    """
    func_dict = {}
    for _finder, modname, _ispkg in pkgutil.walk_packages(
        pycse.__path__, prefix=pycse.__name__ + "."
    ):
        # This module seems to hang the function
        if "sandbox" in modname:
            continue
        try:
            # NOTE: removed leftover debug print() calls here — this server's
            # stdout is the MCP stdio transport, so stray prints can corrupt
            # the protocol stream.
            module = importlib.import_module(modname)
        except Exception:
            # skip modules that error on import
            continue

        for _name, obj in inspect.getmembers(module, inspect.isfunction):
            # only include functions actually defined in pycse
            if obj.__module__.startswith("pycse"):
                qualname = f"{obj.__module__}.{obj.__name__}"
                func_dict[qualname] = inspect.getdoc(obj) or ""

    # Trailing newlines separate the header from the first entry (the header
    # previously ran directly into it).
    s = """The following list of functions are available. They are formatted
    as function : docstring\n\n"""
    for fq, doc in func_dict.items():
        s += f"{fq} : {doc if doc else '<no doc>'}\n\n"

    return s
# * Function help


@mcp.tool()
def get_pydoc_help(func: str) -> str:
    """Use pydoc to get help documentation on func.

    Args:
        func: Function object or string name of function

    Returns:
        str: Help documentation as string
    """
    # pydoc.help writes to stdout, so temporarily swap in a StringIO buffer.
    buffer = StringIO()
    saved_stdout = sys.stdout
    sys.stdout = buffer
    try:
        pydoc.help(func)
    finally:
        # Always restore stdout, even if pydoc raised.
        sys.stdout = saved_stdout

    return buffer.getvalue()
@mcp.tool()
def search_functions(pattern: str) -> str:
    """
    Search for functions matching a pattern using pydoc.

    Args:
        pattern (str): Search pattern

    Returns:
        str: Search results
    """
    # pydoc.apropos prints its matches to stdout; capture them in a buffer.
    buffer = StringIO()
    saved_stdout = sys.stdout
    sys.stdout = buffer
    try:
        pydoc.apropos(pattern)
    finally:
        # Always restore stdout, even if apropos raised.
        sys.stdout = saved_stdout

    return buffer.getvalue()
@mcp.tool()
def get_function_source(qualified_name: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Retrieve the source code for a function given its fully qualified name.

    Parameters
    ----------
    qualified_name : str
        The fully qualified name of the function (e.g., 'numpy.linalg.solve',
        'scipy.optimize.minimize', 'pycse.nlinfit')

    Returns
    -------
    tuple
        A tuple containing (source_code, error_message)
        - source_code: str - The source code of the function, or None if error
        - error_message: str - Error message if retrieval failed, or None if successful

    Examples
    --------
    >>> source, error = get_function_source('numpy.mean')
    >>> if error is None:
    ...     print(source)

    >>> source, error = get_function_source('scipy.optimize.minimize')
    >>> if error:
    ...     print(f"Error: {error}")
    """
    try:
        # Split the qualified name into module path and function name.
        parts = qualified_name.split(".")
        if len(parts) < 2:
            return None, "Function name must be fully qualified (e.g., 'module.function')"

        function_name = parts[-1]
        module_path = ".".join(parts[:-1])

        # Import the module (ModuleNotFoundError is an ImportError subclass).
        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            return None, f"Could not import module '{module_path}': {str(e)}"

        # module_path holds every component except the function name, so the
        # function must be a direct attribute of the imported module. (A
        # previous version had an attribute-navigation loop here, but its
        # slice was always empty for that reason — removed as dead code.)
        if hasattr(module, function_name):
            func_obj = getattr(module, function_name)
        else:
            return None, f"Object has no attribute '{function_name}'"

        # Check if it's callable
        if not callable(func_obj):
            return None, f"'{qualified_name}' is not a callable function"

        # Try to get the source code
        try:
            source = inspect.getsource(func_obj)
            return source, None
        except OSError as e:
            # Source is unavailable for built-in functions, C extensions, etc.
            return None, f"Source code not available for '{qualified_name}': {str(e)}"
        except Exception as e:
            return None, f"Error retrieving source for '{qualified_name}': {str(e)}"

    except Exception as e:
        return None, f"Unexpected error: {str(e)}"
@mcp.tool()
def get_function_info(qualified_name: str) -> Tuple[Optional[dict], Optional[str]]:
    """
    Get comprehensive information about a function including source, signature, and docstring.

    Parameters
    ----------
    qualified_name : str
        The fully qualified name of the function

    Returns
    -------
    tuple
        A tuple containing (info_dict, error_message)
        - info_dict: dict containing 'name', 'module', 'docstring',
          'signature', 'file', 'source'
        - error_message: str if error occurred, None if successful
    """
    try:
        # Split the qualified name
        parts = qualified_name.split(".")
        if len(parts) < 2:
            return None, "Function name must be fully qualified (e.g., 'module.function')"

        function_name = parts[-1]
        module_path = ".".join(parts[:-1])

        # Import the module
        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            return None, f"Could not import module '{module_path}': {str(e)}"

        # module_path holds every component except the function name, so the
        # function must be a direct attribute of the imported module. (The
        # former navigation loop here always iterated zero times — removed.)
        if hasattr(module, function_name):
            func_obj = getattr(module, function_name)
        else:
            return None, f"Object has no attribute '{function_name}'"

        if not callable(func_obj):
            return None, f"'{qualified_name}' is not a callable function"

        # Collect information; each optional piece degrades independently to
        # a readable placeholder instead of failing the whole call.
        info = {
            "name": qualified_name,
            "module": module_path,
            "docstring": inspect.getdoc(func_obj),
        }

        # Try to get signature
        try:
            info["signature"] = str(inspect.signature(func_obj))
        except (ValueError, TypeError):
            info["signature"] = "Signature not available"

        # Try to get source file
        try:
            info["file"] = inspect.getfile(func_obj)
        except (OSError, TypeError):
            info["file"] = "File location not available"

        # Try to get source code
        try:
            info["source"] = inspect.getsource(func_obj)
        except OSError:
            info["source"] = "Source code not available (built-in or C extension)"
        except Exception as e:
            info["source"] = f"Error retrieving source: {str(e)}"

        return info, None

    except Exception as e:
        return None, f"Unexpected error: {str(e)}"
# Module-level logger; used by search_with_apropos to report module errors.
logger = logging.getLogger(__name__)
class AproposResult(BaseModel):
    """Structured result of an apropos-style search, grouped by object kind."""

    # Each dict maps a fully qualified name to a one-line description.
    functions: Dict[str, str] = Field(default_factory=dict)
    classes: Dict[str, str] = Field(default_factory=dict)
    modules: Dict[str, str] = Field(default_factory=dict)
    # Set to None by search_with_apropos when include_methods is False.
    methods: Optional[Dict[str, str]] = Field(default_factory=dict)
    errors: Optional[Dict[str, str]] = Field(
        default_factory=dict
    )  # Optional: module_name -> error msg
@mcp.tool()
def search_with_apropos(
    keywords: Union[str, List[str]],
    modules_to_search: List[str],
    case_sensitive: bool = False,
    include_methods: bool = True,
    include_errors: bool = False,
) -> AproposResult:
    """
    Search for functions, classes, and modules using their docstrings and names.

    Parameters:
    -----------
    keywords : str or list of str
        Search keywords. Multiple keywords are treated as AND conditions.
    modules_to_search : list of str
        List of module names to search. Required for MCP safety.
    case_sensitive : bool
        Whether the search is case sensitive.
    include_methods : bool
        Whether to include methods in the results.
    include_errors : bool
        Whether to include modules that failed with error messages.

    Returns:
    --------
    AproposResult: Structured result object.
    """
    # Normalize a single keyword to a one-element list.
    if isinstance(keywords, str):
        keywords = [keywords]

    # Literal (escaped) substring patterns; every keyword must match (AND).
    patterns: List[Pattern[str]] = [
        re.compile(re.escape(kw), 0 if case_sensitive else re.IGNORECASE) for kw in keywords
    ]

    def matches(text: str) -> bool:
        # True only when ALL keyword patterns appear in the text.
        return all(p.search(text) for p in patterns)

    results = AproposResult()

    for module_name in modules_to_search:
        try:
            # NOTE(review): this relies on importlib.util being reachable via
            # the bare "import importlib" at the top of the file — confirm
            # importlib.util is imported explicitly somewhere.
            spec = importlib.util.find_spec(module_name)
            if spec is None:
                raise ModuleNotFoundError(f"No module spec found for '{module_name}'")

            module = importlib.import_module(module_name)
            module_doc = getattr(module, "__doc__", "") or ""
            # First sentence of the module docstring as a short description.
            module_info = (
                f"{module_name}: {module_doc.split('.')[0] if module_doc else 'No description'}"
            )

            if matches(module_name) or matches(module_doc):
                results.modules[module_name] = module_info.strip()

            # Scan public attributes of the module.
            for attr_name in dir(module):
                if attr_name.startswith("_"):
                    continue

                try:
                    attr = getattr(module, attr_name)
                    attr_doc = getattr(attr, "__doc__", "") or ""
                    full_name = f"{module_name}.{attr_name}"
                    # Match against both the attribute name and its docstring.
                    search_text = f"{attr_name} {attr_doc}"

                    if not matches(search_text):
                        continue

                    attr_desc = attr_doc.split(".")[0] if attr_doc else "No description"
                    desc = f"{full_name}: {attr_desc}".strip()

                    # Classification order matters: classes first, then bound
                    # methods (objects with __self__), then other callables.
                    if callable(attr):
                        if isinstance(attr, type):
                            results.classes[full_name] = desc
                        elif hasattr(attr, "__self__") and include_methods:
                            results.methods[full_name] = desc
                        else:
                            results.functions[full_name] = desc

                except Exception as sub_err:
                    # A single bad attribute should not abort the module scan.
                    if include_errors:
                        results.errors[f"{module_name}.{attr_name}"] = str(sub_err)
                    continue

        except Exception as mod_err:
            # A module that fails to resolve/import is skipped entirely.
            if include_errors:
                results.errors[module_name] = str(mod_err)
            logger.warning(f"Error processing module {module_name}: {mod_err}")
            continue

    # Drop the methods bucket entirely when methods were excluded.
    if not include_methods:
        results.methods = None

    return results
# * DPOSE - Direct Propagation of Shallow Ensembles


class DPOSESpec(BaseModel):
    """Specification for a DPOSE model.

    NOTE(review): no tool in this file appears to consume this spec yet —
    confirm whether it is used elsewhere or reserved for a future tool.
    """

    layers: Tuple[int, int, int] = Field(
        ..., description="Network architecture: (input_dim, hidden_dim, n_ensemble)"
    )
    loss_type: str = Field(default="crps", description="Loss type: 'crps', 'nll', or 'mse'")
    activation: str = Field(
        default="tanh", description="Activation function: 'tanh', 'relu', 'softplus', 'elu'"
    )
    optimizer: str = Field(
        default="bfgs", description="Optimizer: 'bfgs', 'lbfgs', 'adam', 'sgd', 'muon'"
    )
    maxiter: int = Field(default=500, description="Maximum training iterations")
    seed: int = Field(default=42, description="Random seed for reproducibility")
@mcp.tool()
def dpose_info() -> str:
    """Get information about DPOSE and usage examples.

    Returns:
        str: Information about DPOSE, its features, and example usage
    """
    # Static help text, returned verbatim to the client.
    info = """
DPOSE (Direct Propagation of Shallow Ensembles)
================================================

A neural network ensemble method for uncertainty quantification.

Key Features:
- Per-sample uncertainty estimates (heteroscedastic)
- Shallow ensemble architecture (only last layer differs)
- CRPS or NLL loss for calibrated uncertainties
- Uncertainty propagation through transformations
- Handles gaps and extrapolation

Architecture:
- Input layer: matches data dimension
- Hidden layer: typically 15-50 units
- Ensemble: typically 32 members

Example Workflow:
1. Prepare data: X_train (2D), y_train (1D)
2. Train model: use fit() or create via Python
3. Get predictions: predict() returns mean
4. Get uncertainty: predict() with return_std=True

Recommended Settings:
- layers: (n_features, 50, 32) for robust fitting
- loss_type: 'crps' (default, robust)
- activation: 'tanh' (smooth functions)
- optimizer: 'bfgs' (default, fast for small data)
- maxiter: 500 (adjust based on convergence)

Reference:
Kellner, M., & Ceriotti, M. (2024). Uncertainty quantification
by direct propagation of shallow ensembles.
Machine Learning: Science and Technology, 5(3), 035006.
"""
    return info
@mcp.tool()
def dpose_example_code() -> str:
    """Get example Python code for using DPOSE.

    Returns:
        str: Complete example code showing how to use DPOSE
    """
    # This uses DPOSE to ensure it's not an unused import.
    # Note: f-string, so {DPOSE.__name__} is interpolated and the literal
    # braces inside the example are escaped as {{...}}.
    example = f"""
# Example: Using DPOSE for Uncertainty Quantification
# ====================================================

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from pycse.sklearn.dpose import {DPOSE.__name__}

# 1. Generate example data with heteroscedastic noise
np.random.seed(42)
x = np.linspace(0, 1, 200)[:, None]
noise = 0.01 + 0.15 * x.ravel()  # Increasing noise
y = 2 * x.ravel() + noise * np.random.randn(200)

# 2. Split data
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=42
)

# 3. Create and train DPOSE model with StandardScaler
model = Pipeline([
    ('scaler', StandardScaler()),
    ('dpose', DPOSE(
        layers=(1, 50, 32),  # (input, hidden, ensemble)
        loss_type='crps',  # CRPS loss (recommended)
        activation='tanh',  # Tanh activation
        maxiter=500,  # Training iterations
        seed=42
    ))
])

# 4. Fit the model
model.fit(x_train, y_train)

# 5. Make predictions with uncertainty
x_test_scaled = model.named_steps['scaler'].transform(x_test)
y_pred, y_std = model.named_steps['dpose'].predict(
    x_test_scaled,
    return_std=True
)

# 6. Evaluate
mae = np.abs(y_test - model.predict(x_test)).mean()
print(f"MAE: {{mae:.6f}}")
print(f"Uncertainty range: [{{y_std.min():.4f}}, {{y_std.max():.4f}}]")

# 7. For uncertainty propagation through transformations
ensemble = model.named_steps['dpose'].predict_ensemble(x_test_scaled)
z_ensemble = np.exp(ensemble)  # Apply transformation
z_mean = z_ensemble.mean(axis=1)
z_std = z_ensemble.std(axis=1)
"""
    return example
# * Run / install / uninstall the server


def main():
    """Install, uninstall, or run the pycse MCP server.

    This is the cli. If you call it with install or uninstall as an argument,
    it will do that in the Claude Desktop. With no arguments it just runs the
    server.

    Raises:
        Exception: on platforms other than macOS and Windows, since the
            Claude Desktop config location is only known for those.
    """
    if platform.system() == "Darwin":
        cfgfile = "~/Library/Application Support/Claude/claude_desktop_config.json"
    elif platform.system() == "Windows":
        cfgfile = r"%APPDATA%\Claude\claude_desktop_config.json"
    else:
        raise Exception("Only Mac and Windows are supported for the pycse mcp server")

    # Expand %APPDATA% (Windows) and ~ (Mac) to an absolute path.
    cfgfile = os.path.expandvars(cfgfile)
    cfgfile = os.path.expanduser(cfgfile)

    if os.path.exists(cfgfile):
        with open(cfgfile, "r") as f:
            cfg = json.load(f)
    else:
        cfg = {}

    # Called with no arguments just run the server
    if len(sys.argv) == 1:
        mcp.run(transport="stdio")

    elif sys.argv[1] == "install":
        setup = {"command": shutil.which("pycse_mcp")}

        if "mcpServers" not in cfg:
            cfg["mcpServers"] = {}

        cfg["mcpServers"]["pycse"] = setup
        with open(cfgfile, "w") as f:
            json.dump(cfg, f, indent=4)

        # BUGFIX: these messages previously said "litdb" (copy/paste from
        # another project); this is the pycse server.
        print(
            f"\n\nInstalled pycse. Here is your current {cfgfile}. Please restart Claude Desktop."
        )
        print(json.dumps(cfg, indent=4))

    elif sys.argv[1] == "uninstall":
        if "mcpServers" not in cfg:
            cfg["mcpServers"] = {}

        if "pycse" in cfg["mcpServers"]:
            del cfg["mcpServers"]["pycse"]
            with open(cfgfile, "w") as f:
                json.dump(cfg, f, indent=4)

        print(f"Uninstalled pycse. Here is your current {cfgfile}.")
        print(json.dumps(cfg, indent=4))

    else:
        print("I am not sure what you are trying to do. Please use install or uninstall.")