.. _balance_wheel_complication:

==========================================
Complication III: Balance Wheel
==========================================

   *Neural SFS Inference via Differentiable Diffusion*

The Mechanism at a Glance
==========================

Mainspring inverts simulations to infer full ARGs. Escapement uses the coalescent
likelihood on raw genotypes to infer genealogies and :math:`N_e(t)`. Both operate on
**sequence-level** data -- they see every site, every haplotype.

Balance Wheel takes a different path: it operates on the **Site Frequency Spectrum
(SFS)**, the same summary statistic that dadi and moments use. The SFS compresses the
genome into a histogram of allele frequencies -- a massive dimensionality reduction
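(millions of sites :math:`\to` :math:`n - 1` counts for :math:`n` sampled haplotypes,
or an :math:`(n_1 + 1) \times (n_2 + 1)` array for two populations, in which every
entry except the two monomorphic corners is informative). This compression discards
linkage information (LD, haplotype structure), but under the Poisson Random Field
model, which treats sites as independent, the SFS is a sufficient statistic: it
retains everything that model can use for demographic inference.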
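
To make the compression concrete, the following minimal pure-Python sketch builds the
unfolded 1D SFS from per-site derived-allele counts, along with a two-population joint
SFS. It assumes biallelic sites with a known ancestral state (polarized alleles). The
helpers ``unfolded_sfs`` and ``joint_sfs`` are hypothetical illustrations, not dadi or
moments API; both libraries provide their own ``Spectrum`` classes, which mask the
monomorphic entries rather than zeroing them as done here.

.. code-block:: python

   from collections import Counter


   def unfolded_sfs(derived_counts, n):
       """Hypothetical helper: 1D unfolded SFS for n sampled haplotypes.

       Entry j - 1 counts the sites where exactly j of the n haplotypes carry the
       derived allele (j = 1..n-1). Monomorphic sites (j = 0 or j = n) are dropped,
       which is why the spectrum has n - 1 entries.
       """
       tally = Counter(derived_counts)
       return [tally.get(j, 0) for j in range(1, n)]


   def joint_sfs(counts1, counts2, n1, n2):
       """Hypothetical helper: 2D joint SFS as an (n1 + 1) x (n2 + 1) nested list.

       Entry [i][j] counts sites with i derived copies in population 1 and j in
       population 2. The monomorphic corners [0][0] and [n1][n2] are zeroed as a
       simple stand-in for masking; the edges (private polymorphisms) are kept.
       """
       sfs = [[0] * (n2 + 1) for _ in range(n1 + 1)]
       for i, j in zip(counts1, counts2):
           sfs[i][j] += 1
       sfs[0][0] = 0
       sfs[n1][n2] = 0
       return sfs


   # Made-up toy data: 4 haplotypes (rows) x 6 sites (columns), 1 = derived allele.
   haplotypes = [
       [1, 0, 0, 1, 1, 0],
       [1, 1, 0, 0, 1, 0],
       [0, 0, 0, 1, 1, 0],
       [1, 0, 0, 0, 1, 1],
   ]
   counts = [sum(site) for site in zip(*haplotypes)]  # derived copies per site
   print(unfolded_sfs(counts, n=4))  # [2, 1, 1]; sites 3 and 5 are monomorphic

   # Treat the first two and last two haplotypes as two populations (n1 = n2 = 2).
   pop1, pop2 = haplotypes[:2], haplotypes[2:]
   print(joint_sfs([sum(s) for s in zip(*pop1)], [sum(s) for s in zip(*pop2)], 2, 2))
   # [[0, 1, 0], [1, 1, 0], [0, 1, 0]]

The 24 genotype calls of the toy matrix collapse to three numbers; on real data the
same step turns millions of sites into :math:`n - 1` counts.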

The question is: can we replace dadi's PDE solver and moments' ODE integrator with a
neural network, while keeping the same Poisson likelihood and the same demographic
parameters?

**Yes.** And the result is faster, differentiable end-to-end, and handles model classes
that are expensive or intractable for dadi/moments (flexibly parameterized continuous
:math:`N_e(t)`, high-dimensional joint SFS, complex multi-population topologies).

.. admonition:: The name

   The balance wheel is the oscillating component of a mechanical watch that
   regulates the timekeeping. It doesn't track every molecule of air or every
   vibration of the case -- it feels the aggregate force of the balance spring and
   responds with a precise oscillation. The SFS is the aggregate pressure of
   evolution on a genome: it doesn't track every haplotype or every recombination
   event, but it captures the net effect of demographic history on allele
   frequencies. Balance Wheel, like its namesake, works with this aggregate signal.

The three modules of Balance Wheel:

1. **The Demography Encoder** (the hairspring) -- Encodes population size histories,
   split times, and migration rates into a dense vector representation. For
   piecewise-constant demography: a Transformer over (time, size) pairs. For
   continuous demography: a neural ODE. For multi-population models: a GNN over
   the population tree.

2. **The SFS Predictor** (the balance wheel) -- A neural network that maps demographic
   embeddings directly to the expected SFS, replacing dadi's PDE solver and moments'
   ODE integrator. It is trained with moments/dadi as a teacher, not on coalescent
   simulations. Output: :math:`M(\Theta) \in \mathbb{R}^{n-1}` (1D SFS) or
   :math:`M(\Theta) \in \mathbb{R}^{(n_1+1) \times (n_2+1)}` (2D joint SFS, with the
   two monomorphic corners masked).

3. **The Poisson Likelihood** (the impulse pin) -- The exact same Poisson
   log-likelihood that dadi and moments optimize. No approximation, no neural
   network. Given the predicted SFS and the observed SFS, compute the likelihood.
   Gradients of the likelihood with respect to :math:`\Theta` flow back through the
   SFS Predictor and the Demography Encoder via backpropagation.
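
The roles of the three modules can be seen in a deliberately tiny toy. This sketch makes
two loud substitutions. First, the teacher is the closed-form expected SFS under the
standard neutral model, :math:`M_j = \theta / j` for :math:`j = 1, \dots, n-1`, where
:math:`\theta` is the population-scaled mutation rate; it stands in for a moments or
dadi solve, which would handle arbitrary demographies. Second, the student is a
three-weight log-linear model fitted by gradient descent on squared log error, standing
in for the neural SFS Predictor. The student is then scored with the Poisson
log-likelihood :math:`\ell = \sum_j [D_j \ln M_j - M_j - \ln(D_j!)]`. All function
names are hypothetical.

.. code-block:: python

   import math


   def teacher_sfs(theta, n):
       """Stand-in teacher (assumption): expected unfolded SFS under the standard
       neutral model, M_j = theta / j for j = 1..n-1. Balance Wheel's real teacher
       would be a moments/dadi solve for an arbitrary demography."""
       return [theta / j for j in range(1, n)]


   def poisson_loglik(observed, expected):
       """l = sum_j [D_j ln M_j - M_j - ln(D_j!)], using ln(D_j!) = lgamma(D_j + 1)."""
       return sum(d * math.log(m) - m - math.lgamma(d + 1)
                  for d, m in zip(observed, expected))


   def distill_student(n, thetas, lr=0.1, steps=500):
       """Fit a toy student, log M_j = w0 + w1 * x + w2 * y, to teacher spectra by
       gradient descent on squared log error, where x and y are standardized
       log(theta) and log(j). A log-linear model stands in for the neural network."""
       # Training set on a full (theta, j) grid: inputs and log teacher targets.
       rows = [(math.log(t), math.log(j), math.log(m))
               for t in thetas
               for j, m in enumerate(teacher_sfs(t, n), start=1)]
       size = len(rows)
       mx = sum(r[0] for r in rows) / size
       my = sum(r[1] for r in rows) / size
       sx = math.sqrt(sum((r[0] - mx) ** 2 for r in rows) / size)
       sy = math.sqrt(sum((r[1] - my) ** 2 for r in rows) / size)

       def features(log_t, log_j):
           # Standardized inputs keep gradient descent well conditioned.
           return (1.0, (log_t - mx) / sx, (log_j - my) / sy)

       w = [0.0, 0.0, 0.0]
       for _ in range(steps):
           grad = [0.0, 0.0, 0.0]
           for log_t, log_j, target in rows:
               f = features(log_t, log_j)
               err = sum(wk * fk for wk, fk in zip(w, f)) - target
               for k in range(3):
                   grad[k] += 2.0 * err * f[k] / size
           w = [wk - lr * gk for wk, gk in zip(w, grad)]

       def student(theta):
           # Predicted expected SFS for a new theta, one entry per j = 1..n-1.
           return [math.exp(sum(wk * fk for wk, fk in
                                zip(w, features(math.log(theta), math.log(j)))))
                   for j in range(1, n)]

       return student


   student = distill_student(n=10, thetas=[0.5, 1, 2, 4, 8, 16, 32])
   observed = [30, 14, 11, 7, 6, 5, 4, 4, 3]  # made-up counts for n = 10
   print(poisson_loglik(observed, student(10.0)))
   print(poisson_loglik(observed, teacher_sfs(10.0, 10)))  # should agree closely

Because this student's function family happens to contain the teacher exactly, the
distillation is essentially perfect and the two log-likelihoods agree. A neural student
distilled from moments would only approximate its teacher, and that approximation error
is the quantity teacher-student training has to keep small.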

.. code-block:: text

   Demographic parameters Θ
                      |
                      v
            +--------------------------+
            |  DEMOGRAPHY ENCODER      |
            |  Transformer / NeuralODE |
            |  over (time, size) pairs |
            |                          |
            |  Multi-pop: GNN over     |
            |  population tree         |
            +--------------------------+
                      |
                      v
            +--------------------------+
            |  SFS PREDICTOR           |
            |  (replaces PDE/ODE)      |
            |                          |
            |  embedding → M(Θ)        |
            |  Trained with moments    |
            |  as teacher              |
            +--------------------------+
                      |
                      v  expected SFS M(Θ)
            +--------------------------+
            |  POISSON LIKELIHOOD      |  <--  Observed SFS
            |  (same as dadi/moments)  |       D ∈ Z^{n-1}
            |                          |
            |  ℓ(Θ) = Σ [D_j ln M_j    |
            |        - M_j - ln(D_j!)] |
            +--------------------------+
                      |
                      v (optimize Θ, or sample via HMC)
            Demographic parameters:
            point estimate or full posterior
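
To illustrate the final arrow of the diagram (optimize Θ, or sample via HMC), here is a
pure-Python sketch on a toy model, :math:`M_j(\theta) = \theta / j`, parameterized by
:math:`\phi = \ln \theta`. Gradients are propagated by the chain rule,
:math:`\partial \ell / \partial \phi = \sum_j (D_j / M_j - 1)\, \partial M_j / \partial \phi`,
which is the computation backpropagation would carry out through a neural predictor. For
this toy the Poisson MLE has a closed form, Watterson's estimator
:math:`\hat\theta = S / a_n` with :math:`S = \sum_j D_j` and
:math:`a_n = \sum_{j=1}^{n-1} 1/j`, which makes the output easy to check. The sampler
assumes a flat prior on :math:`\phi` and a fixed, hand-picked step size and trajectory
length; the function names are hypothetical, and a real implementation would more likely
use an autodiff framework and an adaptive sampler such as NUTS.

.. code-block:: python

   import math
   import random


   def harmonic(n):
       """a_n = sum_{j=1}^{n-1} 1/j."""
       return sum(1.0 / j for j in range(1, n))


   def loglik_and_grad(phi, observed):
       """Poisson log-likelihood of the toy model M_j = exp(phi) / j and its
       derivative in phi via the chain rule (dM_j/dphi = M_j)."""
       theta = math.exp(phi)
       ll, grad = 0.0, 0.0
       for j, d in enumerate(observed, start=1):
           m = theta / j
           ll += d * math.log(m) - m - math.lgamma(d + 1)
           grad += (d / m - 1.0) * m  # dl/dM_j * dM_j/dphi
       return ll, grad


   def fit_theta(observed, lr=0.01, steps=500):
       """Gradient ascent on phi = ln(theta). Near the optimum this toy is stable
       only for lr < 2 / sum(observed)."""
       phi = 0.0
       for _ in range(steps):
           _, g = loglik_and_grad(phi, observed)
           phi += lr * g
       return math.exp(phi)


   def hmc(observed, phi0, n_samples=2000, step=0.05, n_leapfrog=20, seed=1):
       """Minimal HMC over phi with a flat prior on phi; returns theta draws."""
       rng = random.Random(seed)
       phi = phi0
       ll, g = loglik_and_grad(phi, observed)
       draws = []
       for _ in range(n_samples):
           p0 = rng.gauss(0.0, 1.0)
           new_phi, p, new_ll, new_g = phi, p0, ll, g
           # Leapfrog integration of H(phi, p) = -l(phi) + p**2 / 2.
           p += 0.5 * step * new_g
           for i in range(n_leapfrog):
               new_phi += step * p
               new_ll, new_g = loglik_and_grad(new_phi, observed)
               if i < n_leapfrog - 1:
                   p += step * new_g
           p += 0.5 * step * new_g
           # Metropolis correction for the leapfrog discretization error.
           log_accept = (new_ll - 0.5 * p ** 2) - (ll - 0.5 * p0 ** 2)
           if rng.random() < math.exp(min(0.0, log_accept)):
               phi, ll, g = new_phi, new_ll, new_g
           draws.append(math.exp(phi))
       return draws


   observed = [30, 14, 11, 7, 6, 5, 4, 4, 3]  # made-up counts for n = 10
   theta_w = sum(observed) / harmonic(len(observed) + 1)  # Watterson's estimator
   theta_hat = fit_theta(observed)
   draws = hmc(observed, phi0=math.log(theta_hat))  # warm-start at the MLE
   print(theta_w, theta_hat, sum(draws) / len(draws))

Under the flat prior on :math:`\phi`, the posterior for :math:`\theta` is
Gamma(:math:`S`, :math:`a_n`), so the mean of the HMC draws should also sit near
:math:`S / a_n`. With an informative prior or a neural predictor neither closed form is
available, which is where gradient-based optimization and sampling become necessary.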

.. admonition:: Prerequisites for this Complication

   Balance Wheel builds directly on three Timepieces. Before starting, you should have
   worked through:

   - :ref:`dadi <dadi_timepiece>` -- the Wright-Fisher diffusion PDE and numerical
     SFS computation. Balance Wheel learns to approximate this solver.
   - :ref:`moments <moments_timepiece>` -- the moment-equation ODE system. Used as
     the teacher during training.
   - :ref:`momi2 <momi2_timepiece>` -- coalescent SFS computation for multi-population
     models. Alternative teacher for complex topologies.

   Familiarity with function approximation, knowledge distillation, and Hamiltonian
   Monte Carlo is helpful but not strictly required.

Chapters
========

.. toctree::
   :maxdepth: 2

   overview
   what_dadi_moments_compute
   architecture
   teacher_student_training
   posterior_inference
   multi_population
   comparison
