Architecture

Three wheels, precisely meshed. The hairspring stores the demographic history as elastic energy. The balance wheel converts that energy into the oscillation of allele frequencies. The impulse pin couples the oscillation to the escape wheel of likelihood. Remove any one, and the mechanism keeps no time.

Balance Wheel has three modules, each with a distinct role. Unlike Escapement, which optimizes the ELBO on observed data, and Mainspring, which inverts simulations, Balance Wheel learns a function approximation: the mapping from demographic parameters to the expected SFS. The architecture is correspondingly simpler – no attention over samples, no variational posterior, no Gumbel-softmax. Just encode the demography, predict the SFS, and evaluate the likelihood.

Module 1                    Module 2                    Module 3
DEMOGRAPHY ENCODER          SFS PREDICTOR               POISSON LIKELIHOOD
(the hairspring)            (the balance wheel)         (the impulse pin)

Θ = (sizes, times,          z_Θ ∈ R^d                  M̂(Θ) ∈ R^{n-1}
     migration, ...)              |                          |
     |                            v                          v
     v                     ┌──────────────┐          ┌──────────────┐
┌──────────────┐           │   MLP        │          │  ℓ(Θ) = Σⱼ  │
│ Piecewise:   │           │   d → 256    │          │  [Dⱼ ln M̂ⱼ  │
│  Transformer │ ──────▶   │   256 → 256  │ ──────▶  │   - M̂ⱼ]    │
│ Continuous:  │   z_Θ     │   256 → n-1  │  M̂(Θ)   │             │
│  Neural ODE  │           │   softmax    │          │  (pure math)│
│ Multi-pop:   │           │              │          │             │
│  GNN         │           └──────────────┘          └──────────────┘
└──────────────┘                                           |
                                                           v
                                                  scalar log-likelihood
                                                  ∇_Θ ℓ via backprop

Module 1: Demography Encoder

The encoder transforms raw demographic parameters \(\Theta\) into a fixed-dimensional embedding \(\mathbf{z}_\Theta \in \mathbb{R}^d\). This embedding must capture the essential features of the demographic history – population sizes, their timing, population splits, and migration – in a form that the SFS Predictor can decode.

The challenge is that demographic models are variable-length: a two-epoch model has 3 parameters (two sizes and one time), while a six-epoch model has 11 (six sizes and five times). Multi-population models add split times, migration rates, and population tree topology. The encoder must handle all of these with a single architecture.

Piecewise-Constant Demography: Transformer

For piecewise-constant models (the most common case), the demographic history is a sequence of \((t_k, N_k)\) pairs: epoch start time and population size. This is naturally a sequence-to-vector problem – encode a variable-length sequence into a fixed-dimensional vector.

A small Transformer processes each epoch as a token:

import torch
import torch.nn as nn

class PiecewiseDemographyEncoder(nn.Module):
    def __init__(self, d_model=128, n_heads=4, n_layers=2,
                 max_epochs=10):
        super().__init__()
        self.d_model = d_model
        self.size_embed = nn.Linear(1, d_model)
        self.time_embed = nn.Linear(1, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4*d_model,
            dropout=0.1, batch_first=True)
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers)
        self.pool = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU())

    def forward(self, log_sizes, log_times):
        """
        log_sizes: (batch, K) — log population sizes
        log_times: (batch, K) — log epoch start times
        """
        s = self.size_embed(log_sizes.unsqueeze(-1))
        t = self.time_embed(log_times.unsqueeze(-1))
        x = s + t
        x = self.transformer(x)
        return self.pool(x.mean(dim=1))

The Transformer’s self-attention allows each epoch to attend to every other epoch. This is important because the SFS depends on the relative sizes and cumulative durations of all epochs, not just their local values. A bottleneck followed by an expansion produces a qualitatively different SFS than the expansion alone – the Transformer captures these interactions.

Why log-space inputs

Population sizes span orders of magnitude (\(10^2\) to \(10^6\)) and times span from tens to millions of generations. Working in log-space normalizes these ranges and makes the optimization landscape smoother. The encoder receives \(\log N_e\) and \(\log t\), not the raw values.
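
For concreteness, here is a short usage sketch (the sizes and times below are purely illustrative): the raw values are log-transformed before they reach the encoder, and the most recent epoch is assigned a start time of one generation so the log stays finite.

import torch

# Hypothetical two-epoch history: a contraction from 10,000 to 1,000
# individuals starting 500 generations ago.
sizes = torch.tensor([[1_000.0, 10_000.0]])   # (batch=1, K=2) epoch sizes
times = torch.tensor([[1.0, 500.0]])          # epoch start times (generations)

encoder = PiecewiseDemographyEncoder(d_model=128)
z = encoder(torch.log(sizes), torch.log(times))   # log-space inputs
print(z.shape)                                     # torch.Size([1, 128])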

Continuous Demography: Neural ODE

For continuous \(N_e(t)\) that cannot be represented as piecewise-constant, the encoder uses a neural ODE. The idea is to parameterize the dynamics of population size rather than its values at discrete change points:

\[\frac{d\mathbf{h}}{dt} = f_\theta(\mathbf{h}, t; \Theta), \qquad \mathbf{z}_\Theta = \mathbf{h}(T)\]

where \(f_\theta\) is a small MLP conditioned on the parameters \(\Theta\) that define \(N_e(t)\), and \(\mathbf{h}(t) \in \mathbb{R}^d\) is a latent state that evolves continuously from the present (\(t = 0\)) to the deepest time of interest (\(t = T\)). The initial condition \(\mathbf{h}(0)\) encodes the present-day population size.

import torch
import torch.nn as nn
from torchdiffeq import odeint

class ContinuousDemographyEncoder(nn.Module):
    def __init__(self, d_model=128, n_params=8):
        super().__init__()
        self.d_model = d_model
        self.initial = nn.Linear(1, d_model)
        # The dynamics are conditioned on the latent state, the current time,
        # and the parameters that define N_e(t), so the final state h(T)
        # depends on the whole trajectory, not just the present-day size.
        self.dynamics = nn.Sequential(
            nn.Linear(d_model + 1 + n_params, 256), nn.GELU(),
            nn.Linear(256, d_model))

    def forward(self, ne_func_params, t_eval):
        """
        ne_func_params: (batch, n_params) — parameters defining N_e(t)
                        (e.g., spline knots)
        t_eval: (T,) — time points for ODE integration
        """
        # The initial condition encodes the present-day population size.
        h0 = self.initial(ne_func_params[:, :1])

        def odefunc(t, h):
            t_input = t.expand(h.shape[0], 1)
            return self.dynamics(
                torch.cat([h, t_input, ne_func_params], dim=-1))

        trajectory = odeint(odefunc, h0, t_eval, method='dopri5')
        return trajectory[-1]

This handles arbitrarily complex demographic trajectories – exponential growth, oscillations, gradual bottlenecks – without requiring the user to specify the number of epochs.
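
As a usage sketch, assuming \(N_e(t)\) is parameterized by a handful of log-size spline knots (the knot values and the integration grid below are illustrative):

import torch

# Four hypothetical log-size spline knots for one demography.
ne_func_params = torch.log(torch.tensor([[10_000., 2_000., 50_000., 8_000.]]))
t_eval = torch.linspace(0.0, 1.0, steps=20)   # rescaled time grid, 0 = present

encoder = ContinuousDemographyEncoder(d_model=128, n_params=4)
z = encoder(ne_func_params, t_eval)           # (1, 128) trajectory embedding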

Multi-Population Models: GNN

For \(k\) populations with a known topology (splits, merges, migration edges), the demographic history forms a directed graph. Each node is a population at a given time, and edges represent lineage flow (descent, migration). A graph neural network (GNN) encodes this structure:

import torch
import torch.nn as nn

class PopulationTreeEncoder(nn.Module):
    def __init__(self, d_model=128, n_layers=3):
        super().__init__()
        self.node_embed = nn.Linear(3, d_model)
        self.message_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(2 * d_model + 1, d_model), nn.GELU(),
                nn.Linear(d_model, d_model))
            for _ in range(n_layers)])
        self.update_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(2 * d_model, d_model), nn.GELU(),
                nn.Linear(d_model, d_model))
            for _ in range(n_layers)])
        self.norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)])
        self.readout = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU())

    def forward(self, node_features, edge_index, edge_attr):
        """
        node_features: (n_nodes, 3) — [log_size, log_time, pop_id]
        edge_index: (2, n_edges) — source and target indices
        edge_attr: (n_edges, 1) — migration rates (0 for descent edges)
        """
        h = self.node_embed(node_features)

        for msg_fn, upd_fn, norm in zip(
                self.message_layers, self.update_layers, self.norms):
            src, dst = edge_index
            messages = msg_fn(torch.cat([
                h[src], h[dst], edge_attr], dim=-1))
            agg = torch.zeros_like(h)
            agg.index_add_(0, dst, messages)
            h = norm(h + upd_fn(torch.cat([h, agg], dim=-1)))

        return self.readout(h.mean(dim=0, keepdim=True))

Each node represents a population at a given time, with features \([\log N_e, \log t, \text{pop\_id}]\). Edges represent descent (from ancestral to descendant population) or migration (between contemporary populations). The GNN propagates information along these edges, allowing the embedding to capture the full topology.
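
A small worked example (all values illustrative) builds the graph for a two-population split: an ancestral node at the split time, two present-day daughter nodes, descent edges from the ancestor, and a symmetric migration edge between the daughters.

import math
import torch

# Nodes carry [log_size, log_time, pop_id]; node 0 is the ancestral
# population at the split, nodes 1 and 2 are the present-day daughters.
node_features = torch.tensor([
    [math.log(20_000), math.log(5_000), 0.0],   # ancestor, 5,000 generations ago
    [math.log(10_000), 0.0, 1.0],               # daughter 1, present day
    [math.log(5_000),  0.0, 2.0],               # daughter 2, present day
])

# Descent edges ancestor -> daughters, plus migration between the daughters.
edge_index = torch.tensor([[0, 0, 1, 2],
                           [1, 2, 2, 1]])
edge_attr = torch.tensor([[0.0], [0.0], [1e-4], [1e-4]])   # migration rates

encoder = PopulationTreeEncoder(d_model=128)
z = encoder(node_features, edge_index, edge_attr)           # (1, 128)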

Module 2: SFS Predictor

The SFS Predictor maps the demographic embedding \(\mathbf{z}_\Theta\) to the expected SFS \(\hat{\mathbf{M}}(\Theta)\). This is the core of Balance Wheel – the module that replaces the PDE/ODE solver.

Architecture

The predictor is a standard MLP with a softmax output to ensure positivity:

class SFSPredictor(nn.Module):
    def __init__(self, d_model=128, hidden=256, n_layers=4, max_n=100):
        super().__init__()
        self.max_n = max_n
        layers = [nn.Linear(d_model + 1, hidden), nn.GELU()]
        for _ in range(n_layers - 2):
            layers.extend([nn.Linear(hidden, hidden), nn.GELU()])
        layers.append(nn.Linear(hidden, max_n))
        self.mlp = nn.Sequential(*layers)

    def forward(self, z, n, theta_L):
        """
        z: (batch, d_model) — demographic embedding
        n: int — sample size
        theta_L: float — θ · L scaling factor
        """
        n_input = torch.full((z.shape[0], 1), n / self.max_n,
                             device=z.device)
        raw = self.mlp(torch.cat([z, n_input], dim=-1))
        raw = raw[:, :n - 1]
        sfs = torch.softmax(raw, dim=-1) * theta_L
        return sfs

The softmax output, scaled by \(\theta L\), enforces two constraints:

  1. Positivity: \(\hat{M}_j > 0\) for all \(j\). This is essential because the Poisson log-likelihood requires \(\ln M_j\), which is undefined for \(M_j \leq 0\).

  2. Normalization: \(\sum_j \hat{M}_j = \theta L\). The total expected number of segregating sites is fixed by the mutation rate and sequence length. The softmax distributes this total across frequency bins.

Why softmax and not ReLU

A naive alternative is a ReLU output: it guarantees non-negativity but leaves the output unnormalized. This fails for two reasons. First, ReLU can produce exact zeros, which make \(\ln 0 = -\infty\) appear in the Poisson likelihood. Second, without normalization the total \(\sum_j \hat{M}_j\) is unconstrained, so the network must learn the correct total from data – an unnecessary burden. Softmax handles both constraints simultaneously.

Sample-size conditioning

The predictor takes the sample size \(n\) as an input (normalized by \(n_{\max}\)). This allows a single trained network to predict the SFS for any sample size up to \(n_{\max}\), rather than training separate networks for each \(n\). The sample size affects the SFS through the binomial sampling step – larger \(n\) resolves finer frequency classes – and the network must learn this dependence.
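
A quick sanity check (with a random vector standing in for \(\mathbf{z}_\Theta\) and an arbitrary \(\theta L\)) illustrates both constraints and the sample-size conditioning: the same network returns SFS vectors of different lengths, each strictly positive and summing to \(\theta L\).

import torch

predictor = SFSPredictor(d_model=128, max_n=100)
z = torch.randn(1, 128)      # stand-in for a demographic embedding
theta_L = 4_000.0            # hypothetical θ·L

for n in (20, 100):
    sfs = predictor(z, n, theta_L)
    assert sfs.shape == (1, n - 1)
    assert (sfs > 0).all()                                   # positivity
    assert torch.isclose(sfs.sum(), torch.tensor(theta_L))   # sums to θL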

Module 3: Poisson Likelihood

Module 3 contains no learnable parameters. It is pure mathematics: the exact same Poisson log-likelihood that dadi and moments optimize.

\[\ell(\Theta) = \sum_{j=1}^{n-1} \left[ D_j \ln \hat{M}_j(\Theta) - \hat{M}_j(\Theta) - \ln(D_j!) \right]\]

where \(D_j\) is the observed SFS count and \(\hat{M}_j(\Theta)\) is the neural prediction. The last term is constant and can be dropped during optimization.

def poisson_log_likelihood(observed_sfs, expected_sfs):
    """Exact Poisson log-likelihood — identical to dadi/moments.

    observed_sfs: (n-1,) integer tensor — observed SFS counts
    expected_sfs: (n-1,) positive tensor — predicted expected SFS
    """
    M = expected_sfs.clamp(min=1e-10)
    ll = observed_sfs * torch.log(M) - M - torch.lgamma(observed_sfs + 1)
    return ll.sum()

The critical property of this module is what it does not do: it does not approximate the likelihood. The Poisson model is exact (under the PRF assumptions). The only approximation in Balance Wheel is in Module 2 – the neural SFS prediction \(\hat{\mathbf{M}}(\Theta) \approx \mathbf{M}(\Theta)\). The likelihood evaluation itself is lossless.

Gradient flow

The gradient of the log-likelihood with respect to \(\Theta\) flows through all three modules via backpropagation:

\[\nabla_\Theta \,\ell = \sum_{j=1}^{n-1} \left(\frac{D_j}{\hat{M}_j} - 1\right) \cdot \nabla_\Theta \hat{M}_j\]

The term \(\nabla_\Theta \hat{M}_j\) is the Jacobian of the neural SFS prediction with respect to the demographic parameters. PyTorch computes this automatically via its autograd engine. No finite differences, no adjoint ODE solves, no numerical Jacobians.
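
A minimal sketch of that gradient flow, reusing the encoder and predictor defined above (the demographic parameters and the fake observed SFS are illustrative): mark the parameters as requiring gradients, evaluate the log-likelihood, and let autograd return \(\nabla_\Theta \ell\).

import torch

log_sizes = torch.log(torch.tensor([[1_000., 10_000.]])).requires_grad_()
log_times = torch.log(torch.tensor([[1., 500.]])).requires_grad_()

encoder = PiecewiseDemographyEncoder(d_model=128)
predictor = SFSPredictor(d_model=128, max_n=100)
observed_sfs = torch.randint(0, 500, (19,)).float()   # fake data, n = 20

expected_sfs = predictor(encoder(log_sizes, log_times), n=20, theta_L=4_000.0)
ll = poisson_log_likelihood(observed_sfs, expected_sfs.squeeze(0))
grad_sizes, grad_times = torch.autograd.grad(ll, [log_sizes, log_times])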

Gradient quality comparison

  • dadi: finite differences. \(\nabla_{\Theta_i} \ell \approx [\ell(\Theta + \epsilon \mathbf{e}_i) - \ell(\Theta - \epsilon \mathbf{e}_i)] / (2\epsilon)\). Requires \(2|\Theta|\) forward solves. Sensitive to \(\epsilon\) choice. Can be inaccurate for nearly flat likelihoods.

  • moments: AD through ODE solver. Exact in principle, but can suffer from numerical instability for stiff systems (large population size changes) and high memory cost (storing intermediate ODE states for backprop).

  • Balance Wheel: standard backprop through a small MLP. Exact, fast, numerically stable. The MLP has no stiff dynamics, no PDE to discretize, no frequency grid to resolve.

Putting It All Together

The complete Balance Wheel model chains the three modules:

class BalanceWheel(nn.Module):
    def __init__(self, d_model=128, n_heads=4, n_layers_enc=2,
                 n_layers_pred=4, hidden=256, max_n=100):
        super().__init__()
        self.encoder = PiecewiseDemographyEncoder(
            d_model=d_model, n_heads=n_heads, n_layers=n_layers_enc)
        self.predictor = SFSPredictor(
            d_model=d_model, hidden=hidden,
            n_layers=n_layers_pred, max_n=max_n)

    def forward(self, log_sizes, log_times, n, theta_L):
        z = self.encoder(log_sizes, log_times)
        expected_sfs = self.predictor(z, n, theta_L)
        return expected_sfs

    def log_likelihood(self, log_sizes, log_times, observed_sfs, theta_L):
        n = observed_sfs.shape[-1] + 1
        expected_sfs = self.forward(log_sizes, log_times, n, theta_L)
        return poisson_log_likelihood(observed_sfs, expected_sfs)
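
With a trained network, fitting a demography to data is ordinary gradient-based optimization of the log-likelihood over the demographic parameters, with the network weights held fixed. A sketch under that assumption (the starting values, learning rate, and placeholder data are illustrative):

import torch

model = BalanceWheel().eval()             # assumed already trained
for p in model.parameters():
    p.requires_grad_(False)               # freeze the network weights

observed_sfs = torch.randint(0, 500, (19,)).float()   # placeholder data, n = 20
log_sizes = torch.log(torch.tensor([[5_000., 5_000.]])).requires_grad_()
log_times = torch.log(torch.tensor([[1., 100.]])).requires_grad_()

opt = torch.optim.Adam([log_sizes, log_times], lr=0.05)
for _ in range(200):
    opt.zero_grad()
    nll = -model.log_likelihood(log_sizes, log_times, observed_sfs,
                                theta_L=4_000.0)
    nll.backward()
    opt.step()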

Computational complexity per module

  • Demography Encoder: \(O(K^2 d)\) for the Transformer (self-attention over \(K\) epochs, typically \(K \leq 10\)), or \(O(S d)\) for the Neural ODE with \(S\) solver steps.

  • SFS Predictor: \(O(H^2 L_{\text{layers}})\) for the MLP; the bottleneck is the matrix multiplications in the hidden layers.

  • Poisson Likelihood: \(O(n)\); a single pass over the \(n - 1\) SFS entries.

  • Backpropagation: the same cost as the forward pass; automatic, with no additional algorithmic complexity.

Total: \(O(K^2 d + H^2 L_{\text{layers}} + n)\), dominated by the MLP’s hidden layers. For typical parameters (\(K \leq 10\), \(d = 128\), \(H = 256\), \(L_{\text{layers}} = 4\), \(n \leq 100\)), the forward pass takes ~0.1 ms on a modern GPU. Compare this to moments’ ~10 ms for a single population or dadi’s ~100 ms – a 100–1000× speedup.
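
These figures depend on hardware and batch size; a minimal way to time the forward pass yourself, assuming a CUDA device and batching many demographies to amortize kernel-launch overhead:

import time
import torch

model = BalanceWheel().cuda().eval()
log_sizes = torch.randn(1024, 5, device='cuda')   # 1,024 five-epoch models
log_times = torch.randn(1024, 5, device='cuda')

with torch.no_grad():
    torch.cuda.synchronize()
    start = time.perf_counter()
    model(log_sizes, log_times, n=100, theta_L=4_000.0)
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) * 1e3
print(f"{elapsed_ms / 1024:.4f} ms per demography")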

The architecture is deliberately simple

Balance Wheel uses MLPs where Mainspring uses Transformers and Escapement uses variational posteriors with Gumbel-softmax. This simplicity is not a limitation – it reflects the nature of the problem. The mapping \(\Theta \to \mathbf{M}(\Theta)\) is a smooth, low-dimensional function. It does not require the expressive power of attention mechanisms or the statistical sophistication of variational inference. A well-trained MLP is the right tool for this job, just as the balance wheel – a simple oscillating mass – is the right regulator for a mechanical watch.