.. _balance_wheel_what_they_compute:

=============================================
What dadi and moments Actually Compute
=============================================

   *Before you can replace a movement, you must understand every wheel, every
   jewel, every beat of the escapement. The apprentice who shortcuts this step
   builds a watch that looks right but keeps the wrong time.*

Balance Wheel replaces the forward SFS computation of
:ref:`dadi <dadi_timepiece>` and :ref:`moments <moments_timepiece>` with a
neural network. To understand what we are replacing -- and why -- we need to
dissect the classical methods at the level of their PDEs and ODEs. This chapter
is that dissection.


The Shared Framework
======================

All three classical methods -- dadi, moments, and :ref:`momi2 <momi2_timepiece>`
-- solve the same inference problem:

.. math::

   \hat{\Theta} = \arg\max_\Theta \;\ell(\Theta), \qquad
   \ell(\Theta) = \sum_{j=1}^{n-1} \left[ D_j \ln M_j(\Theta) - M_j(\Theta)
   \right]

where :math:`D_j` is the observed SFS count at frequency bin :math:`j` and
:math:`M_j(\Theta)` is the expected count under demographic model
:math:`\Theta`. The constant :math:`-\ln(D_j!)` is dropped because it does not
depend on :math:`\Theta`.

The likelihood is the Poisson log-likelihood of the observed SFS given the
expected SFS. It is the same whether the expected SFS is computed by a PDE
solver, an ODE integrator, a coalescent computation, or a neural network.
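This likelihood is trivial to evaluate once :math:`\mathbf{M}(\Theta)` is in
hand. A minimal NumPy sketch (the function name is ours, not any package's
API):

.. code-block:: python

   import numpy as np

   def poisson_loglik(D, M):
       """Poisson log-likelihood of observed SFS D given expected SFS M,
       dropping the Theta-independent ln(D_j!) constant."""
       D = np.asarray(D, dtype=float)
       M = np.asarray(M, dtype=float)
       return float(np.sum(D * np.log(M) - M))

Maximizing this over :math:`\Theta` only requires re-evaluating
:math:`\mathbf{M}(\Theta)`, and that re-evaluation is exactly where the
three methods diverge.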

The three methods differ only in how they compute
:math:`\mathbf{M}(\Theta) = (M_1(\Theta), \ldots, M_{n-1}(\Theta))`:

.. list-table:: Three paths to the expected SFS
   :header-rows: 1
   :widths: 15 35 25 25

   * - Method
     - Approach
     - Intermediate quantity
     - Cost per evaluation
   * - :ref:`dadi <dadi_timepiece>`
     - Solve Wright-Fisher diffusion PDE
     - Allele frequency density :math:`\phi(x, t)`
     - :math:`O(G^k)` where :math:`G` = grid size
   * - :ref:`moments <moments_timepiece>`
     - Integrate ODE for SFS entries
     - SFS vector :math:`\Phi(t)` directly
     - :math:`O(n^k)` per step
   * - :ref:`momi2 <momi2_timepiece>`
     - Coalescent computation of branch lengths
     - Expected branch lengths per frequency class
     - :math:`O(n^2)` per population pair


dadi: The Wright-Fisher Diffusion PDE
========================================

dadi's approach begins with the diffusion approximation to the Wright-Fisher
model (see :ref:`diffusion_approximation` for the prerequisite derivation). For
a single population with variable size :math:`\nu(t) = N_e(t) / N_{\text{ref}}`,
the allele frequency density :math:`\phi(x, t)` evolves according to:

.. math::

   \frac{\partial \phi}{\partial t} = \frac{1}{2\nu(t)} \cdot
   \frac{\partial^2}{\partial x^2}\!\left[x(1-x)\,\phi\right]

The initial condition is the equilibrium spectrum for a population of constant
size :math:`\nu_0`:

.. math::

\phi(x, 0) = \frac{\theta \nu_0}{x}

where :math:`\theta = 4 N_{\text{ref}} \mu` is the population-scaled mutation
rate. The boundaries at :math:`x = 0` and :math:`x = 1` act as absorbing
states: density that drifts there corresponds to loss or fixation of the
variant, while new mutations continuously re-inject density near
:math:`x = 0`, sustaining the spectrum.

From :math:`\phi(x, t)` to the SFS
-------------------------------------

Once the PDE is solved to the present time :math:`t_{\text{present}}`, the
expected SFS is obtained by integrating the frequency density against binomial
sampling weights:

.. math::

   M_j = L \cdot \int_0^1 \binom{n}{j} x^j (1-x)^{n-j} \,
   \phi(x, t_{\text{present}}) \, dx

This integral maps the continuous allele frequency density to the discrete
frequency bins of a sample of size :math:`n`. The binomial coefficient
accounts for the probability of sampling exactly :math:`j` derived alleles
from a site at population frequency :math:`x`. Note that :math:`\theta`
enters through :math:`\phi` itself (via the mutation influx), so it does not
appear as a separate factor; :math:`L` converts the per-site expectation into
a count over the sequence.
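A minimal quadrature version of this projection (trapezoid rule on an
interior grid, since the density diverges at :math:`x = 0`; ``sfs_from_phi``
is a name invented here, not dadi's API):

.. code-block:: python

   import numpy as np
   from math import comb

   def sfs_from_phi(phi, x, n):
       """Project a frequency density phi, sampled on the grid x, onto
       the expected SFS of a size-n sample via binomial weights."""
       M = np.empty(n - 1)
       for j in range(1, n):
           y = comb(n, j) * x**j * (1 - x)**(n - j) * phi
           # trapezoid rule, written out for portability
           M[j - 1] = float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0)
       return M

Feeding in the classical equilibrium density :math:`\theta/x` reproduces
Watterson's :math:`M_j = \theta/j`, which makes a convenient unit test for
any forward solver.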

Numerical discretization
--------------------------

dadi discretizes the frequency axis on a grid of :math:`G + 1` points
:math:`0 = x_0 < x_1 < \ldots < x_G = 1` and the time axis into
piecewise-constant epochs. Within each epoch, the population size is constant,
and the PDE is solved by the Crank-Nicolson finite-difference scheme:

.. math::

   \phi^{t+1} = \left(I - \frac{\Delta t}{2} A\right)^{-1}
   \left(I + \frac{\Delta t}{2} A\right) \phi^t

where :math:`A` is the tridiagonal matrix encoding the
:math:`\frac{1}{2\nu}\frac{\partial^2}{\partial x^2}[x(1-x) \cdot]` operator.
Each time step requires solving a tridiagonal system -- :math:`O(G)` per step.

For :math:`k` populations, the frequency axis becomes a :math:`k`-dimensional
grid with :math:`G^k` points. The PDE gains cross-population terms for
migration and population splits. The cost per time step becomes :math:`O(G^k)`,
which is the fundamental bottleneck.

.. code-block:: python

   import numpy as np

   def crank_nicolson_step(phi, dx, dt, nu):
       """One Crank-Nicolson step of the 1D diffusion PDE.

       The operator is the second difference of the product x(1-x)*phi,
       so each off-diagonal entry uses the diffusion coefficient at the
       neighboring grid point, not the row's own.
       """
       G = len(phi)
       x = np.linspace(0, 1, G)
       f = x * (1 - x) / (2 * nu)  # diffusion coefficient at each node

       A = np.zeros((G, G))
       for i in range(1, G - 1):
           A[i, i - 1] = f[i - 1] / dx**2
           A[i, i] = -2 * f[i] / dx**2
           A[i, i + 1] = f[i + 1] / dx**2
       # Boundary rows stay zero; loss, fixation, and mutation injection
       # are handled by separate bookkeeping.

       lhs = np.eye(G) - 0.5 * dt * A
       rhs = (np.eye(G) + 0.5 * dt * A) @ phi
       # A production solver would exploit the tridiagonal structure with
       # an O(G) banded solve; the dense solve keeps the sketch short.
       return np.linalg.solve(lhs, rhs)

.. admonition:: Why piecewise-constant demography

   The Crank-Nicolson scheme requires :math:`\nu` to be constant within each
   time step. dadi approximates continuous demography by piecewise-constant
   functions -- constant within each epoch, with instantaneous jumps between
   epochs. The number of epochs is a user-specified modeling choice. Too few
   epochs underfit the true demography; too many create identifiability
   issues. This is a limitation that Balance Wheel eliminates by learning a
   smooth mapping that can handle continuously varying :math:`N_e(t)`.


moments: ODE for the SFS Directly
====================================

While dadi works with the full frequency density :math:`\phi(x, t)` and then
integrates to get the SFS, :ref:`moments <moments_timepiece>` takes a shortcut:
it derives ODEs for the SFS entries themselves.

The key insight is that the SFS entries
:math:`\Phi_j = E[D_j]` are the :math:`j`-th moments of the frequency
distribution with respect to binomial sampling. The diffusion PDE for
:math:`\phi(x, t)` implies an ODE for each :math:`\Phi_j`:

.. math::

   \frac{d\Phi_j}{dt} = \underbrace{\frac{1}{2\nu(t)} \left[
   (j+1)(n-j-1) \Phi_{j+1} - 2j(n-j) \Phi_j + (j-1)(n-j+1) \Phi_{j-1}
   \right]}_{\text{genetic drift}}
   + \underbrace{\frac{n\theta}{2} \cdot \delta_{j,1}}_{\text{mutation}}

where the drift operator is a tridiagonal linear transformation of the SFS
vector, and mutation introduces new variants at frequency :math:`1/n`
(corresponding to a single derived allele in the sample). As a sanity check,
at :math:`\nu = 1` the two terms cancel exactly for
:math:`\Phi_j = \theta/j`, the classical Watterson spectrum.

In matrix notation:

.. math::

   \frac{d\Phi}{dt} = \frac{1}{2\nu(t)} D \cdot \Phi
   + \frac{n\theta}{2} \cdot \mathbf{e}_1

where :math:`D` is the :math:`(n-1) \times (n-1)` drift matrix and
:math:`\mathbf{e}_1` is the first standard basis vector.

Advantages over dadi
-----------------------

1. **No frequency grid.** moments works directly with the :math:`n - 1` SFS
   entries, avoiding the :math:`G`-point frequency grid. For small sample sizes,
   :math:`n - 1 \ll G`.

2. **Sparse operators.** The drift matrix :math:`D` is tridiagonal. Each ODE
   step is :math:`O(n)` for one population, versus :math:`O(G)` for dadi.
   For :math:`k` populations, the SFS has :math:`\prod_i (n_i - 1)` entries and
   the drift operator is sparse -- cost is :math:`O(n^k)` but with a much
   smaller constant than dadi's :math:`O(G^k)`.

3. **Automatic differentiation.** moments can differentiate through the ODE
   solver to compute exact gradients :math:`\nabla_\Theta \mathbf{M}(\Theta)`,
   avoiding the finite-difference approximation that dadi uses.

.. code-block:: python

   import numpy as np
   from scipy.integrate import solve_ivp

   def drift_matrix(n):
       """Tridiagonal drift matrix for the 1D SFS of sample size n."""
       size = n - 1
       D = np.zeros((size, size))
       for j in range(size):
           jj = j + 1  # 1-based frequency class
           if j > 0:
               D[j, j - 1] = (jj - 1) * (n - jj + 1)
           D[j, j] = -2 * jj * (n - jj)
           if j < size - 1:
               D[j, j + 1] = (jj + 1) * (n - jj - 1)
       return D

   def integrate_sfs(nu_func, times, n, theta):
       """Integrate the SFS ODE through a series of epochs."""
       D = drift_matrix(n)
       size = n - 1
       mutation = np.zeros(size)
       mutation[0] = n * theta / 2  # influx of new singletons

       def rhs(t, phi, nu):
           return (1.0 / (2 * nu)) * D @ phi + mutation

       phi = theta / np.arange(1, n)  # equilibrium start: Phi_j = theta/j

       for i in range(len(times) - 1):
           nu = nu_func(times[i])
           dt = times[i + 1] - times[i]
           sol = solve_ivp(lambda t, y: rhs(t, y, nu),
                           [0, dt], phi, method='RK45')
           phi = sol.y[:, -1]

       return phi

The remaining bottleneck
--------------------------

For :math:`k` populations, the joint SFS has :math:`\prod_{i=1}^k (n_i - 1)`
entries. The drift operator for each population acts along one axis of the
:math:`k`-dimensional tensor. The cost per ODE step is :math:`O(n^k)` (or more
precisely, :math:`O(\prod n_i)` times the cost of applying the drift operator
along each axis). For :math:`k = 3` with :math:`n = 20`:

.. math::

   \text{SFS entries:} \quad 19 \times 19 \times 19 = 6{,}859

Each ODE step applies the drift operator along each of the three axes.
Treating it as a dense :math:`19 \times 19` matrix acting on every fiber
costs roughly :math:`3 \times 19 \times 19^3 \approx 4 \times 10^5`
operations per step (the tridiagonal structure shrinks the constant, not the
scaling). With hundreds of ODE steps per epoch and multiple epochs, a single
SFS evaluation can take minutes.
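The axis-wise application described above can be sketched with
``np.tensordot`` (a generic illustration, not moments' actual code; equal
sample sizes assumed for brevity):

.. code-block:: python

   import numpy as np

   def apply_drift(Phi, D):
       """Apply one (n-1)x(n-1) drift matrix D along every axis of the
       k-dimensional joint SFS tensor Phi and sum the contributions."""
       out = np.zeros_like(Phi)
       for axis in range(Phi.ndim):
           # contract D's column index against this tensor axis, then
           # move the resulting index back into place
           term = np.tensordot(D, Phi, axes=([1], [axis]))
           out += np.moveaxis(term, 0, axis)
       return out

With :math:`D = I` this returns :math:`k \cdot \Phi`, a quick correctness
check; swapping the dense product for a banded one shrinks the constant but
not the :math:`O(n^k)` scaling.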


momi2: The Coalescent Path
=============================

:ref:`momi2 <momi2_timepiece>` takes a completely different approach. Instead
of solving a PDE or ODE forward in time (past :math:`\to` present), it works
backward in time (present :math:`\to` past), using coalescent theory to
compute the expected SFS from the expected branch lengths of the genealogy.

The expected number of SNPs at frequency :math:`j/n` is:

.. math::

   M_j = \mu \cdot L \cdot E\!\left[\sum_{e \in \text{edges}} b(e) \cdot
   \mathbf{1}[\text{desc}(e) = j]\right]

where :math:`b(e)` is the branch length of edge :math:`e` (in generations) and
:math:`\text{desc}(e)` is the number of descendant leaves below :math:`e`. This
says: the expected SFS at frequency :math:`j` is proportional to the total
expected branch length subtending exactly :math:`j` leaves.
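For a single constant-size population this identity has a closed form: the
expected branch length subtending :math:`j` leaves is :math:`4N/j`
generations, so :math:`M_j = \mu L \cdot 4N/j = \theta L / j`. A tiny
sanity-check sketch (the function name is ours):

.. code-block:: python

   def expected_sfs_constant(n, N, mu, L):
       """Expected SFS of one constant-size population, using the
       coalescent identity E[branch length with j descendants] = 4N/j
       generations, so M_j = mu * L * 4N / j = theta * L / j."""
       theta = 4 * N * mu
       return [theta * L / j for j in range(1, n)]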

momi2 computes these expected branch lengths using a tensor machinery that
propagates through population splits, merges, and migration events (see
:ref:`momi2_tensor`). For multi-population models, this can be more efficient
than the PDE/ODE approach because the cost scales with the number of populations
and their relationships, not with the grid resolution or sample size.

.. admonition:: Three paths, one destination

   dadi, moments, and momi2 all produce the same :math:`\mathbf{M}(\Theta)` (up
   to numerical precision). They are three different algorithms for the same
   mathematical quantity. Balance Wheel adds a fourth path -- the neural path --
   that is faster than all three but requires a one-time training cost.


Computational Bottlenecks
============================

Why do these methods scale poorly, and what exactly limits them?

.. list-table:: Scaling bottlenecks
   :header-rows: 1
   :widths: 14 22 22 22 20

   * - Method
     - 1-pop cost
     - 2-pop cost
     - 3-pop cost
     - Wall-clock (3-pop, n=20)
   * - dadi
     - :math:`O(G)`
     - :math:`O(G^2)`
     - :math:`O(G^3)`
     - Hours (if :math:`G \geq 40`)
   * - moments
     - :math:`O(n)`
     - :math:`O(n^2)`
     - :math:`O(n^3)`
     - ~60 s
   * - momi2
     - :math:`O(n)`
     - :math:`O(n^2)`
     - :math:`O(n^3)`
     - ~30 s (coalescent)
   * - Balance Wheel
     - :math:`O(1)`
     - :math:`O(1)`
     - :math:`O(1)`
     - **~0.1 ms**

The exponential scaling in :math:`k` is the fundamental problem. For four or
more populations, even moments becomes impractical. dadi's authors recommend
:math:`k \leq 3` and warn that :math:`k = 3` already requires hours.

The gradient computation compounds the issue:

- **dadi**: uses finite differences. Each gradient evaluation requires
  :math:`2 \times |\Theta|` forward SFS evaluations (two perturbations per
  parameter). For 10 parameters, that is 20 PDE solves per gradient step.

- **moments**: uses automatic differentiation through the ODE solver, requiring
  only one forward pass plus one backward pass. But the backward pass through
  a stiff ODE can be numerically unstable, and the memory cost scales with the
  number of ODE steps (unless adjoint methods are used).

- **momi2**: computes gradients analytically through its tensor machinery. Fast
  for small models, but the tensor contractions become expensive for complex
  topologies.
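The finite-difference pattern is generic; a sketch of the central-difference
gradient (two forward evaluations per parameter) makes the
:math:`2 \times |\Theta|` cost concrete:

.. code-block:: python

   import numpy as np

   def finite_diff_grad(loglik, theta, eps=1e-5):
       """Central-difference gradient of a scalar objective:
       2 * len(theta) forward evaluations in total."""
       theta = np.asarray(theta, dtype=float)
       grad = np.empty_like(theta)
       for i in range(theta.size):
           e = np.zeros_like(theta)
           e[i] = eps
           grad[i] = (loglik(theta + e) - loglik(theta - e)) / (2 * eps)
       return grad

Each call to ``loglik`` hides a full forward SFS evaluation, which is why
this pattern is so expensive for a PDE-based forward solver.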

For Balance Wheel, the gradient is a single backpropagation through a small
neural network -- :math:`O(1)` with a small constant, regardless of the number
of populations or the complexity of the demographic model.

The disparity is easiest to see with concrete numbers. Taking a classical
forward pass of ~10 ms and a finite-difference gradient with
:math:`|\Theta| = 10` parameters:

.. math::

   \text{Gradient cost ratio} = \frac{\text{finite differences (dadi-style)}}
   {\text{Balance Wheel (backprop)}} \approx
   \frac{2 \times |\Theta| \times 10\,\text{ms}}{0.1\,\text{ms}}
   \approx 2000\times

For HMC sampling, where each step requires ~10 gradient evaluations, even
moments' autodiff gradient (one forward plus one backward pass, ~20 ms)
yields roughly 200 ms per HMC step, versus ~1 ms for Balance Wheel: the
difference between impractical and routine.


What Balance Wheel Must Learn
================================

To summarize, Balance Wheel's SFS Predictor must learn a mapping
:math:`\Theta \to \mathbf{M}(\Theta)` that captures:

1. **The effect of genetic drift** on allele frequencies. Expansions shift
   the SFS toward rare variants (excess singletons); bottlenecks deplete rare
   variation and flatten the spectrum toward intermediate frequencies. The
   timing and magnitude of size changes produce characteristic distortions.

2. **The interaction of multiple populations.** Migration homogenizes the
   joint SFS. Splits create private alleles; recent divergence or ongoing
   gene flow leaves shared low-frequency variants. The joint SFS encodes the
   full history of population interactions.

3. **The integration over the mutation process.** Each SFS entry is an integral
   (or sum) over the expected branch lengths of the genealogy, weighted by the
   mutation rate. The network must implicitly perform this integration.

4. **The numerical constraints.** All SFS entries must be positive
   (:math:`M_j > 0`), and the total expected number of segregating sites
   :math:`\sum_j M_j = \mu L \cdot E[T_{\text{total}}]` (branch lengths in
   generations, matching the momi2 expression above) must be consistent with
   coalescent theory.

The remarkable fact is that a modest MLP (3--4 layers, 256 hidden units) can
learn this mapping to high accuracy across a wide range of demographic models.
The SFS is a smooth, low-dimensional function of :math:`\Theta` -- much smoother
than, say, the mapping from genotypes to genealogies that
:ref:`Mainspring <mainspring_complication>` must learn. This is why Balance
Wheel succeeds with a simple architecture, whereas Mainspring and Escapement
require Transformers and GNNs.
