# What dadi and moments Actually Compute
Before you can replace a movement, you must understand every wheel, every jewel, every beat of the escapement. The apprentice who shortcuts this step builds a watch that looks right but keeps the wrong time.
Balance Wheel replaces the forward SFS computation of dadi and moments with a neural network. To understand what we are replacing – and why – we need to dissect the classical methods at the level of their PDEs and ODEs. This chapter is that dissection.
## dadi: The Wright-Fisher Diffusion PDE
dadi’s approach begins with the diffusion approximation to the Wright-Fisher model (see The Diffusion Approximation for the prerequisite derivation). For a single population with variable size \(\nu(t) = N_e(t) / N_{\text{ref}}\), the allele frequency density \(\phi(x, t)\) evolves according to:
\[
\frac{\partial \phi}{\partial t} = \frac{1}{2\,\nu(t)}\, \frac{\partial^2}{\partial x^2}\!\left[x(1-x)\,\phi\right]
\]
The initial condition is the equilibrium spectrum for a population of constant size \(\nu_0\):
\[
\phi(x, 0) = \frac{\theta\, \nu_0}{x}
\]
where \(\theta = 4 N_{\text{ref}} \mu\) is the population-scaled mutation rate. The boundary conditions enforce no flux at \(x = 0\) and \(x = 1\) (alleles cannot spontaneously appear or fix through drift alone; those transitions are handled by the mutation terms).
### From \(\phi(x, t)\) to the SFS
Once the PDE is solved to the present time \(t_{\text{present}}\), the expected SFS is obtained by integrating the frequency density against binomial sampling weights:
\[
M_j = \int_0^1 \binom{n}{j}\, x^j (1-x)^{n-j}\, \phi(x, t_{\text{present}})\, dx, \qquad j = 1, \ldots, n-1
\]
This integral maps the continuous allele frequency density to the discrete frequency bins of a sample of size \(n\). The binomial coefficient accounts for the probability of sampling exactly \(j\) derived alleles from a site at population frequency \(x\).
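The projection can be sketched numerically. The version below assumes a uniform grid and the trapezoid rule; dadi itself uses a nonuniform grid with Richardson extrapolation, so `sfs_from_phi` is an illustrative helper, not dadi's API:

```python
import numpy as np
from scipy.stats import binom

def sfs_from_phi(phi, n):
    """Project a frequency density phi, tabulated on a uniform grid
    over [0, 1], onto the expected SFS of a sample of size n using
    binomial sampling weights and the trapezoid rule."""
    G = len(phi)
    x = np.linspace(0.0, 1.0, G)
    dx = x[1] - x[0]
    M = np.empty(n - 1)
    for j in range(1, n):
        # C(n, j) x^j (1 - x)^(n - j) * phi(x), integrated over x
        f = binom.pmf(j, n, x) * phi
        M[j - 1] = 0.5 * dx * np.sum(f[:-1] + f[1:])  # trapezoid rule
    return M
```

As a sanity check, a constant density \(\phi \equiv 1\) gives \(M_j = \binom{n}{j} B(j+1, n-j+1) = 1/(n+1)\) for every \(j\).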
### Numerical discretization
dadi discretizes the frequency axis on a grid of \(G\) points \(0 = x_0 < x_1 < \ldots < x_G = 1\) and the time axis into piecewise-constant epochs. Within each epoch, the population size is constant, and the PDE is solved by the Crank-Nicolson finite-difference scheme:
\[
\left(I - \tfrac{\Delta t}{2} A\right) \phi^{t+\Delta t} = \left(I + \tfrac{\Delta t}{2} A\right) \phi^{t}
\]
where \(A\) is the tridiagonal matrix encoding the \(\frac{1}{2\nu}\frac{\partial^2}{\partial x^2}[x(1-x) \cdot]\) operator. Each time step requires solving a tridiagonal system – \(O(G)\) per step.
For \(k\) populations, the frequency axis becomes a \(k\)-dimensional grid with \(G^k\) points. The PDE gains cross-population terms for migration and population splits. The cost per time step becomes \(O(G^k)\), which is the fundamental bottleneck.
```python
import numpy as np

def crank_nicolson_step(phi, dx, dt, nu):
    """One Crank-Nicolson step of the 1D Wright-Fisher diffusion PDE."""
    G = len(phi)
    x = np.linspace(0, 1, G)
    # Diffusion coefficient x(1-x)/(2 nu); it vanishes at both
    # boundaries, consistent with the no-flux conditions at x = 0, 1.
    diffusion = x * (1 - x) / (2 * nu)
    diag = np.zeros(G)
    upper = np.zeros(G - 1)
    lower = np.zeros(G - 1)
    for i in range(1, G - 1):
        # Second difference of the *product* diffusion(x) * phi(x):
        # the off-diagonal entries carry the neighboring coefficients.
        diag[i] = -2 * diffusion[i] / dx**2
        upper[i] = diffusion[i + 1] / dx**2      # couples to phi[i+1]
        lower[i - 1] = diffusion[i - 1] / dx**2  # couples to phi[i-1]
    A = np.diag(diag) + np.diag(upper, 1) + np.diag(lower, -1)
    # Crank-Nicolson: (I - dt/2 A) phi_new = (I + dt/2 A) phi_old.
    lhs = np.eye(G) - 0.5 * dt * A
    rhs = (np.eye(G) + 0.5 * dt * A) @ phi
    # A banded solver (scipy.linalg.solve_banded) would make this O(G);
    # the dense solve here keeps the sketch short.
    return np.linalg.solve(lhs, rhs)
```
### Why piecewise-constant demography
The Crank-Nicolson scheme requires \(\nu\) to be constant within each time step. dadi approximates continuous demography by piecewise-constant functions – constant within each epoch, with instantaneous jumps between epochs. The number of epochs is a user-specified modeling choice. Too few epochs underfit the true demography; too many create identifiability issues. This is a limitation that Balance Wheel eliminates by learning a smooth mapping that can handle continuously varying \(N_e(t)\).
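The epoch discretization itself is simple. A minimal sketch, assuming midpoint sampling of the size history (`piecewise_constant_epochs` is an illustrative helper, not dadi's API):

```python
import numpy as np

def piecewise_constant_epochs(nu_func, t_end, n_epochs):
    """Approximate a continuous size history nu(t) on [0, t_end] by
    piecewise-constant epochs, taking nu at each epoch's midpoint."""
    edges = np.linspace(0.0, t_end, n_epochs + 1)
    midpoints = 0.5 * (edges[:-1] + edges[1:])
    return edges, nu_func(midpoints)

# Example: exponential growth nu(t) = exp(2t), chopped into 4 epochs.
edges, nus = piecewise_constant_epochs(lambda t: np.exp(2 * t), 1.0, 4)
```

Within each epoch the solver then treats \(\nu\) as the constant `nus[i]`; the staircase converges to the true history as the number of epochs grows, but each added epoch is another parameter to fit.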
## moments: ODE for the SFS Directly
While dadi works with the full frequency density \(\phi(x, t)\) and then integrates to get the SFS, moments takes a shortcut: it derives ODEs for the SFS entries themselves.
The key insight is that the SFS entries \(\Phi_j = E[D_j]\) are the \(j\)-th moments of the frequency distribution with respect to binomial sampling. The diffusion PDE for \(\phi(x, t)\) implies an ODE for each \(\Phi_j\):
\[
\frac{d\Phi_j}{dt} = \frac{1}{2\,\nu(t)} \left(D \boldsymbol{\Phi}\right)_j + \theta\, \delta_{j,1}
\]
where the drift operator is a tridiagonal linear transformation of the SFS vector, and mutation introduces new variants at frequency \(1/n\) (corresponding to a single derived allele in the sample).
In matrix notation:
\[
\frac{d\boldsymbol{\Phi}}{dt} = \frac{1}{2\,\nu(t)}\, D\, \boldsymbol{\Phi} + \theta\, \mathbf{e}_1
\]
where \(D\) is the \((n-1) \times (n-1)\) drift matrix and \(\mathbf{e}_1\) is the first standard basis vector.
### Advantages over dadi
- **No frequency grid.** moments works directly with the \(n - 1\) SFS entries, avoiding the \(G\)-point frequency grid. For small sample sizes, \(n - 1 \ll G\).
- **Sparse operators.** The drift matrix \(D\) is tridiagonal. Each ODE step is \(O(n)\) for one population, versus \(O(G)\) for dadi. For \(k\) populations, the SFS has \(\prod_i (n_i - 1)\) entries and the drift operator is sparse – cost is \(O(n^k)\) but with a much smaller constant than dadi’s \(O(G^k)\).
- **Automatic differentiation.** moments can differentiate through the ODE solver to compute exact gradients \(\nabla_\Theta \mathbf{M}(\Theta)\), avoiding the finite-difference approximation that dadi uses.
```python
import numpy as np
from scipy.integrate import solve_ivp

def drift_matrix(n):
    """Tridiagonal drift matrix for the 1D SFS of sample size n.

    Row j (1-indexed frequency jj = j + 1) implements
    dPhi_jj/dt ~ (jj-1)(n-jj+1) Phi_{jj-1} - 2 jj (n-jj) Phi_jj
                 + (jj+1)(n-jj-1) Phi_{jj+1}.
    """
    size = n - 1
    D = np.zeros((size, size))
    for j in range(size):
        jj = j + 1
        if j > 0:
            D[j, j - 1] = (jj - 1) * (n - jj + 1)
        D[j, j] = -2 * jj * (n - jj)
        if j < size - 1:
            D[j, j + 1] = (jj + 1) * (n - jj - 1)
    return D

def integrate_sfs(nu_func, times, n, theta):
    """Integrate the SFS ODE through a series of piecewise-constant epochs."""
    D = drift_matrix(n)
    size = n - 1
    mutation = np.zeros(size)
    mutation[0] = theta  # new variants enter as singletons (frequency 1/n)

    def rhs(t, phi, nu):
        return (1.0 / (2 * nu)) * D @ phi + mutation

    # Initial condition: equilibrium of the ODE at the ancestral size,
    # i.e. the solution of (1 / (2 nu0)) D phi + mutation = 0.
    nu0 = nu_func(times[0])
    phi = np.linalg.solve(D, -2.0 * nu0 * mutation)
    for i in range(len(times) - 1):
        nu = nu_func(times[i])
        dt = times[i + 1] - times[i]
        sol = solve_ivp(lambda t, y: rhs(t, y, nu),
                        [0, dt], phi, method='RK45')
        phi = sol.y[:, -1]
    return phi
```
### The remaining bottleneck
For \(k\) populations, the joint SFS has \(\prod_{i=1}^k (n_i - 1)\) entries. The drift operator for each population acts along one axis of the \(k\)-dimensional tensor. The cost per ODE step is \(O(n^k)\) (or more precisely, \(O(\prod n_i)\) times the cost of applying the drift operator along each axis). For \(k = 3\) with \(n = 20\):
\[
\prod_{i=1}^{3} (n_i - 1) = 19^3 = 6859 \text{ entries}
\]
Each ODE step requires applying the drift operator along each of the three axes – a total of roughly \(3 \times 19 \times 19^3 \approx 4 \times 10^5\) operations. With hundreds of ODE steps per epoch and multiple epochs, a single SFS evaluation can take minutes.
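The arithmetic is easy to verify (the op count below assumes a dense drift matrix applied along each tensor axis, matching the estimate above):

```python
n, k = 20, 3
entries = (n - 1) ** k                 # joint-SFS entries: 19**3 = 6859
ops_per_step = k * (n - 1) * entries   # dense drift applied along each axis
print(entries, ops_per_step)           # 6859 390963
```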
## momi2: The Coalescent Path
momi2 takes a completely different approach. Instead of solving a PDE or ODE in forward time (present \(\to\) past), it uses coalescent theory to compute the expected SFS from the expected branch lengths of the genealogy.
The expected number of SNPs at frequency \(j/n\) is:
\[
E[D_j] = \mu \sum_{e \,:\, \text{desc}(e) = j} E[b(e)]
\]
where \(b(e)\) is the branch length of edge \(e\) (in generations) and \(\text{desc}(e)\) is the number of descendant leaves below \(e\). This says: the expected SFS at frequency \(j\) is proportional to the total expected branch length subtending exactly \(j\) leaves.
momi2 computes these expected branch lengths using a tensor machinery that propagates through population splits, merges, and migration events (see momi2_tensor). For multi-population models, this can be more efficient than the PDE/ODE approach because the cost scales with the number of populations and their relationships, not with the grid resolution or sample size.
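The branch-length identity can be checked directly in the simplest case. The Monte Carlo sketch below simulates standard constant-size Kingman coalescent trees (not momi2's tensor machinery) and recovers the classical neutral expectation \(E[D_j] = \theta / j\); the scaling assumes time in units of \(2N\) generations:

```python
import numpy as np

def expected_branch_lengths(n, reps, rng):
    """Monte Carlo estimate of E[L_j]: expected total branch length
    (in units of 2N generations) subtending exactly j of n leaves,
    under the constant-size Kingman coalescent."""
    L = np.zeros(n)  # L[j] accumulates length subtending j leaves
    for _ in range(reps):
        desc = [1] * n  # descendant counts of the active lineages
        while len(desc) > 1:
            k = len(desc)
            # waiting time ~ Exp(rate = C(k, 2))
            dt = rng.exponential(2.0 / (k * (k - 1)))
            for d in desc:
                L[d] += dt
            i, j = sorted(rng.choice(k, size=2, replace=False))
            merged = desc[i] + desc[j]
            del desc[j], desc[i]  # remove the larger index first
            desc.append(merged)
    return L / reps

rng = np.random.default_rng(0)
n, theta = 10, 1.0
L = expected_branch_lengths(n, 10_000, rng)
sfs = 0.5 * theta * L[1:n]  # E[D_j] = (theta / 2) E[L_j] ~ theta / j
```

With \(E[L_j] = 2/j\) in these time units, the estimate `sfs` should hover near \(\theta/j\) for each frequency class.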
### Three paths, one destination
dadi, moments, and momi2 all produce the same \(\mathbf{M}(\Theta)\) (up to numerical precision). They are three different algorithms for the same mathematical quantity. Balance Wheel adds a fourth path – the neural path – that is faster than all three but requires a one-time training cost.
## Computational Bottlenecks
Why do these methods scale poorly, and what exactly limits them?
| Method | 1-pop cost | 2-pop cost | 3-pop cost | Wall-clock (3-pop, \(n = 20\)) |
|---|---|---|---|---|
| dadi | \(O(G)\) | \(O(G^2)\) | \(O(G^3)\) | Hours (if \(G \geq 40\)) |
| moments | \(O(n)\) | \(O(n^2)\) | \(O(n^3)\) | ~60 s |
| momi2 | \(O(n)\) | \(O(n^2)\) | \(O(n^3)\) | ~30 s (coalescent) |
| Balance Wheel | \(O(1)\) | \(O(1)\) | \(O(1)\) | ~0.1 ms |
The exponential scaling in \(k\) is the fundamental problem. For four or more populations, even moments becomes impractical. dadi’s authors recommend \(k \leq 3\) and warn that \(k = 3\) already requires hours.
The gradient computation compounds the issue:
- **dadi** uses finite differences. Each gradient evaluation requires \(2 \times |\Theta|\) forward SFS evaluations (two perturbations per parameter). For 10 parameters, that is 20 PDE solves per gradient step.
- **moments** uses automatic differentiation through the ODE solver, requiring only one forward pass plus one backward pass. But the backward pass through a stiff ODE can be numerically unstable, and the memory cost scales with the number of ODE steps (unless adjoint methods are used).
- **momi2** computes gradients analytically through its tensor machinery. Fast for small models, but the tensor contractions become expensive for complex topologies.
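To make the finite-difference cost concrete, here is a generic central-difference gradient for a scalar objective such as a log-likelihood (`fd_gradient` is an illustrative helper, not dadi's actual routine):

```python
import numpy as np

def fd_gradient(f, theta, eps=1e-4):
    """Central-difference gradient of scalar f at theta:
    costs 2 * len(theta) forward evaluations of f."""
    theta = np.asarray(theta, dtype=float)
    grad = np.empty_like(theta)
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        grad[i] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return grad
```

When each call to `f` is itself a multi-second PDE solve, the `2 * len(theta)` factor dominates the cost of optimization.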
For Balance Wheel, the gradient is a single backpropagation through a small neural network – \(O(1)\) with a small constant, regardless of the number of populations or the complexity of the demographic model.
For HMC sampling, where each step requires ~10 gradient evaluations, this translates to 200 ms per HMC step (moments) versus 1 ms (Balance Wheel). The difference between impractical and routine.
## What Balance Wheel Must Learn
To summarize, Balance Wheel’s SFS Predictor must learn a mapping \(\Theta \to \mathbf{M}(\Theta)\) that captures:
- **The effect of genetic drift on allele frequencies.** Bottlenecks shift the SFS toward rare variants (excess singletons). Expansions flatten it. The timing and magnitude of size changes produce characteristic distortions.
- **The interaction of multiple populations.** Migration homogenizes the joint SFS. Splits create private alleles. Ancient divergence produces shared low-frequency variants. The joint SFS encodes the full history of population interactions.
- **The integration over the mutation process.** Each SFS entry is an integral (or sum) over the expected branch lengths of the genealogy, weighted by the mutation rate. The network must implicitly perform this integration.
- **The numerical constraints.** All SFS entries must be positive (\(M_j > 0\)), and the total expected number of segregating sites \(\sum_j M_j = \theta L \cdot E[T_{\text{total}}]\) must be consistent with coalescent theory.
The remarkable fact is that a modest MLP (3–4 layers, 256 hidden units) can learn this mapping to high accuracy across a wide range of demographic models. The SFS is a smooth, low-dimensional function of \(\Theta\) – much smoother than, say, the mapping from genotypes to genealogies that Mainspring must learn. This is why Balance Wheel works with a simple architecture, whereas Mainspring and Escapement require Transformers and GNNs.
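As an illustration of how small such a network can be, here is a minimal numpy sketch of an MLP whose softplus output layer enforces the positivity constraint \(M_j > 0\). The layer sizes, activations, and initialization are assumptions for illustration, not Balance Wheel's actual architecture:

```python
import numpy as np

def softplus(z):
    """Numerically stable log(1 + exp(z)); always strictly positive."""
    return np.logaddexp(0.0, z)

def mlp_sfs(theta, params):
    """Map demographic parameters theta to a positive SFS vector.
    params is a list of (W, b) layer weights; tanh hidden layers,
    softplus output guarantees every predicted entry is > 0."""
    h = np.asarray(theta, dtype=float)
    for W, b in params[:-1]:
        h = np.tanh(W @ h + b)
    W, b = params[-1]
    return softplus(W @ h + b)

# Illustrative shapes: 5 parameters in, three 256-unit hidden layers,
# 19 SFS entries out (1-pop SFS for n = 20).
rng = np.random.default_rng(0)
sizes = [5, 256, 256, 256, 19]
params = [(rng.normal(0.0, 0.1, (m, n)), np.zeros(m))
          for n, m in zip(sizes[:-1], sizes[1:])]
sfs = mlp_sfs(rng.normal(size=5), params)
```

The softplus head makes the positivity constraint structural rather than learned; the consistency of \(\sum_j M_j\) with coalescent theory, by contrast, must come from the training data.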