Metadata-Version: 2.2
Name: jax-metallib
Version: 0.9.4
Summary: JAX backend for Apple M-series chips
Keywords: jax,metal,mps,apple,gpu,machine-learning
License: Apache-2.0
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: MacOS
Classifier: Programming Language :: Python :: 3
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Project-URL: Homepage, https://github.com/erfanzar/jax-metallib
Project-URL: Repository, https://github.com/erfanzar/jax-metallib
Requires-Python: >=3.11
Requires-Dist: jax<0.10,>=0.9.0
Requires-Dist: jaxlib<0.10,>=0.9.0
Description-Content-Type: text/markdown

# jax-metallib

A [PJRT](https://openxla.org/xla/pjrt_integration) plugin that enables
[JAX](https://github.com/jax-ml/jax) to run on Apple Metal (MPS) GPUs on Apple
Silicon. It compiles StableHLO IR to Metal shaders via MPSGraph, giving JAX
programs GPU acceleration on M-series Macs.

> **Status:** Alpha (v0.9.4) -- API and op coverage are evolving.

## Requirements

- macOS 13+ on Apple Silicon (M1 / M2 / M3 / M4)
- Python 3.11+
- `jax` and `jaxlib` 0.9.x

## Install

```bash
pip install jax-metallib
```

Verify:

```bash
JAX_PLATFORMS=mps python -c "import jax; print(jax.devices())"
# [MpsDevice(id=0)]
```

## Build from source

```bash
brew install cmake ninja
git clone https://github.com/erfanzar/jax-metallib.git
cd jax-metallib
uv sync --all-groups
uv pip install -e .       # auto-bootstraps native deps on first build (~30 min)
```

To skip the automatic dependency bootstrap (if you manage deps yourself):

```bash
CMAKE_ARGS="-DJAX_SILICON_AUTO_SETUP_DEPS=OFF" uv pip install -e .
```

### Native dependencies

`scripts/setup_deps.sh` fetches and builds:

| Dependency  | Version / Commit  |
| ----------- | ----------------- |
| LLVM + MLIR | XLA pin `bb760b0` |
| StableHLO   | `127d2f2`         |
| Abseil      | 20250127.0        |
| Protobuf    | 29.3              |

These are installed to `~/.local/jax-silicon-deps` by default.

## Usage

The plugin registers as the `mps` platform in JAX:

```python
import jax
import jax.numpy as jnp

# Assumes JAX_PLATFORMS=mps is set in the environment, or that
# jax.config.update("jax_platforms", "mps") ran before any array was created.
x = jnp.ones((1024, 1024))
y = x @ x  # runs on Metal GPU
```
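As a slightly fuller sketch, the snippet below jits a gradient and reports which backend JAX picked. The `loss` function and the array shapes are made up for illustration. Nothing in the code forces the `mps` platform, so it falls back to JAX's default backend when the plugin is not installed.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical toy objective: mean squared activation of a linear layer.
    return jnp.mean(jnp.tanh(x @ w) ** 2)


# Gradient with respect to the first argument (w), compiled on first call.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.ones((64, 32)) * 0.01
x = jnp.ones((8, 64))
g = grad_fn(w, x)

# Expected to report "mps" when the plugin is selected (for example via JAX_PLATFORMS=mps).
print(jax.default_backend())
print(g.shape)  # (64, 32)
```

Run it as `JAX_PLATFORMS=mps python script.py` to target the Metal GPU.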

### Environment variables

| Variable                   | Description                                          |
| -------------------------- | ---------------------------------------------------- |
| `JAX_PLATFORMS=mps`        | Select the MPS backend                               |
| `JAX_SILICON_LIBRARY_PATH` | Override path to `libpjrt_plugin_silicon.dylib`      |
| `JAX_MPS_LIBRARY_PATH`     | Legacy alias for the above                           |
| `MPS_LOG_LEVEL=0..3`       | Logging verbosity (0=error, 1=warn, 2=info, 3=debug) |
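
When launching JAX scripts from Python (for example from a benchmark driver), it can help to build the environment in one place. The helper below is a hypothetical sketch, not part of the plugin; `mps_env` and its parameters are names invented here, and only the variable names come from the table above.

```python
import os


def mps_env(log_level=2, library_path=None, base=None):
    """Return a copy of `base` (default: os.environ) configured for the MPS backend."""
    if not 0 <= log_level <= 3:
        raise ValueError("MPS_LOG_LEVEL must be between 0 and 3")
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "mps"
    env["MPS_LOG_LEVEL"] = str(log_level)
    if library_path is not None:
        # Only needed when overriding the bundled libpjrt_plugin_silicon.dylib.
        env["JAX_SILICON_LIBRARY_PATH"] = library_path
    return env


env = mps_env(log_level=3, base={})
print(env)  # {'JAX_PLATFORMS': 'mps', 'MPS_LOG_LEVEL': '3'}
```

The result can be passed to a child process, e.g. `subprocess.run([sys.executable, "script.py"], env=mps_env())`, so the variables are in place before JAX is imported.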

## Supported operations

100+ StableHLO operations are implemented across these categories:

| Category       | Examples                                                                   |
| -------------- | -------------------------------------------------------------------------- |
| Unary          | `tanh`, `exp`, `log`, `sin`, `cos`, `sqrt`, `erf`, `abs`, `sign`           |
| Binary         | `add`, `subtract`, `multiply`, `divide`, `dot`, comparisons                |
| Reductions     | `reduce_sum`, `reduce_max`, `reduce_min`, `argmax`, `argmin`               |
| Shape          | `reshape`, `transpose`, `slice`, `pad`, `concatenate`, `gather`, `scatter` |
| Convolution    | `conv_general_dilated` with arbitrary padding/dilation                     |
| Linear algebra | `matmul`, `cholesky`, `qr`, `svd`, `triangular_solve`                      |
| FFT            | `fft`, `rfft`, `ifft`, `irfft`                                             |
| Random         | Threefry / Philox RNG                                                      |
| Sort           | `sort`, `top_k`                                                            |
| Control flow   | `cond` (if/else), `while_loop`, `scan`                                     |
| Bitwise        | `and`, `or`, `xor`, `shift_left`, `shift_right`                            |

When the plugin encounters an unsupported op, it prints a diagnostic with a link
for filing a feature request.

## Testing

```bash
uv run pytest                   # compare CPU vs MPS (default)
JAX_TEST_MODE=mps uv run pytest # MPS only
JAX_TEST_MODE=cpu uv run pytest # CPU only
```

The test suite covers value correctness and gradient accuracy, and includes
integration tests with [Flax](https://github.com/google/flax) and
[NumPyro](https://github.com/pyro-ppl/numpyro).
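
The CPU-vs-MPS comparison can be pictured roughly as follows. This is a minimal sketch, not the project's actual test harness: `run_on`, the function `f`, and the tolerances are assumptions chosen for illustration. It also assumes both platforms are visible in the same process (for example with `JAX_PLATFORMS` unset, or set to `mps,cpu`).

```python
import jax
import jax.numpy as jnp
import numpy as np


def run_on(platform, fn, *args):
    """Run fn with its inputs placed on the first device of `platform`."""
    device = jax.devices(platform)[0]
    placed = [jax.device_put(a, device) for a in args]
    return np.asarray(fn(*placed))


def f(x):
    return jnp.tanh(x) @ x.T


x = np.linspace(-1.0, 1.0, 12, dtype=np.float32).reshape(3, 4)
expected = run_on("cpu", f, x)

try:
    actual = run_on("mps", f, x)
except RuntimeError:
    # The mps backend is not available in this process; nothing to compare against.
    actual = expected

# Tolerances are illustrative; GPU float32 results rarely match the CPU bit for bit.
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)
```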

## How it works

```text
JAX Python code
      |
StableHLO IR (MLIR)
      |
stablehlo_parser.mm   -- parses IR, looks up ops in the registry
      |
MPSGraph operations   -- builds a Metal compute graph
      |
Metal command buffer  -- compiled & dispatched to GPU
      |
Device memory result
```

The PJRT C API (`pjrt_api.cc`) exposes client, device, buffer, and executable
abstractions that JAX expects. Each StableHLO op is registered in
`src/pjrt_plugin/ops/` and mapped to the corresponding MPSGraph method.
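
To illustrate the registry idea, here is a toy model in Python. It is only a conceptual sketch: the real backend is C++/Obj-C++ and builds MPSGraph nodes, whereas this version evaluates plain Python floats eagerly. The names `REGISTRY`, `register`, and `lower`, and the tuple-based program format, are all invented for this example.

```python
from typing import Callable, Dict

# Toy registry: op name -> handler.
REGISTRY: Dict[str, Callable[..., float]] = {}


def register(name):
    """Decorator that records a handler under the given op name."""
    def wrap(fn):
        REGISTRY[name] = fn
        return fn
    return wrap


@register("stablehlo.add")
def _add(a, b):
    return a + b


@register("stablehlo.multiply")
def _mul(a, b):
    return a * b


def lower(program, inputs):
    """Walk (result, op, operands) triples in order, looking each op up in the registry."""
    values = dict(inputs)
    for result, op, operands in program:
        handler = REGISTRY.get(op)
        if handler is None:
            raise NotImplementedError(f"unsupported op: {op}")
        values[result] = handler(*(values[name] for name in operands))
    return values


program = [
    ("t0", "stablehlo.add", ("x", "y")),
    ("t1", "stablehlo.multiply", ("t0", "x")),
]
out = lower(program, {"x": 2.0, "y": 3.0})
print(out["t1"])  # 10.0
```

Keeping the lookup in one table means an unknown op fails at a single, predictable point, which is where a diagnostic can be emitted.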

## Repository layout

```text
src/
  jax_plugins/silicon/    Python entrypoint (plugin registration)
  pjrt_plugin/            C++/Obj-C++ PJRT backend
    ops/                  Op implementations (100+ ops)
    stablehlo_parser.mm   StableHLO IR -> MPSGraph compiler
    mps_client.mm         Metal device & command queue management
    mps_executable.mm     Executable compilation & dispatch
tests/
  test_ops.py             Parametrized op tests
  configs/                Per-category test configurations
scripts/
  setup_deps.sh           One-time native dependency bootstrap
  release.sh              Release automation
```

## Contributing

See [CONTRIBUTING.md](CONTRIBUTING.md). The short version:

1. `brew install cmake ninja && ./scripts/setup_deps.sh && uv sync --all-groups`
2. `uv pip install -e . && pre-commit install`
3. Add your op in `src/pjrt_plugin/ops/`, add a test in `tests/configs/`, rebuild, and run `uv run pytest`.

## License

Apache-2.0. See [LICENSE](LICENSE).

This project is a derivative of [`tillahoffmann/jax-mps`](https://github.com/tillahoffmann/jax-mps).
