.. _balance_wheel_architecture:

==============
Architecture
==============

   *Three wheels, precisely meshed. The hairspring stores the demographic
   history as elastic energy. The balance wheel converts that energy into the
   oscillation of allele frequencies. The impulse pin couples the oscillation
   to the escape wheel of likelihood. Remove any one, and the mechanism keeps
   no time.*

Balance Wheel has three modules, each with a distinct role. Unlike
:ref:`Escapement <escapement_complication>`, which optimizes the ELBO on
observed data, and :ref:`Mainspring <mainspring_complication>`, which inverts
simulations, Balance Wheel learns a *function approximation*: the mapping from
demographic parameters to the expected SFS. The architecture is correspondingly
simpler -- no attention over samples, no variational posterior, no Gumbel-softmax.
Just encode the demography, predict the SFS, and evaluate the likelihood.

.. code-block:: text

   Module 1                    Module 2                    Module 3
   DEMOGRAPHY ENCODER          SFS PREDICTOR               POISSON LIKELIHOOD
   (the hairspring)            (the balance wheel)         (the impulse pin)

   Θ = (sizes, times,          z_Θ ∈ R^d                  M̂(Θ) ∈ R^{n-1}
        migration, ...)              |                          |
        |                            v                          v
        v                     ┌──────────────┐          ┌──────────────┐
   ┌──────────────┐           │   MLP        │          │  ℓ(Θ) = Σⱼ  │
   │ Piecewise:   │           │   d → 256    │          │  [Dⱼ ln M̂ⱼ  │
   │  Transformer │ ──────▶   │   256 → 256  │ ──────▶  │   - M̂ⱼ]    │
   │ Continuous:  │   z_Θ     │   256 → n-1  │  M̂(Θ)   │             │
   │  Neural ODE  │           │   softmax    │          │  (pure math)│
   │ Multi-pop:   │           │              │          │             │
   │  GNN         │           └──────────────┘          └──────────────┘
   └──────────────┘                                           |
                                                              v
                                                     scalar log-likelihood
                                                     ∇_Θ ℓ via backprop


Module 1: Demography Encoder
===============================

The encoder transforms raw demographic parameters :math:`\Theta` into a
fixed-dimensional embedding :math:`\mathbf{z}_\Theta \in \mathbb{R}^d`. This
embedding must capture the essential features of the demographic history --
population sizes, their timing, population splits, and migration -- in a form
that the SFS Predictor can decode.

The challenge is that demographic models are variable-length: a two-epoch model
has 3 parameters (two sizes and one change time), while a six-epoch model has 11.
Multi-population models add split times, migration rates, and population tree
topology. The encoder must handle all of these with a single architecture.

Piecewise-Constant Demography: Transformer
---------------------------------------------

For piecewise-constant models (the most common case), the demographic history is
a sequence of :math:`(t_k, N_k)` pairs: epoch start time and population size.
This is naturally a sequence-to-vector problem -- encode a variable-length
sequence into a fixed-dimensional vector.

A small Transformer processes each epoch as a token:

.. code-block:: python

   import torch
   import torch.nn as nn

   class PiecewiseDemographyEncoder(nn.Module):
       def __init__(self, d_model=128, n_heads=4, n_layers=2):
           super().__init__()
           self.d_model = d_model
           # Separate embeddings for each epoch's log-size and log-start-time.
           self.size_embed = nn.Linear(1, d_model)
           self.time_embed = nn.Linear(1, d_model)
           encoder_layer = nn.TransformerEncoderLayer(
               d_model=d_model, nhead=n_heads, dim_feedforward=4*d_model,
               dropout=0.1, batch_first=True)
           self.transformer = nn.TransformerEncoder(
               encoder_layer, num_layers=n_layers)
           self.pool = nn.Sequential(
               nn.Linear(d_model, d_model), nn.GELU())

       def forward(self, log_sizes, log_times):
           """
           log_sizes: (batch, K) — log population sizes
           log_times: (batch, K) — log epoch start times
           """
           # Each epoch becomes one token: the sum of its size and time embeddings.
           s = self.size_embed(log_sizes.unsqueeze(-1))
           t = self.time_embed(log_times.unsqueeze(-1))
           x = s + t
           # Self-attention over epochs, then mean-pool to a fixed-size embedding.
           x = self.transformer(x)
           return self.pool(x.mean(dim=1))

The Transformer's self-attention allows each epoch to attend to every other
epoch. This is important because the SFS depends on the *relative* sizes and
*cumulative* durations of all epochs, not just their local values. A bottleneck
followed by an expansion produces a qualitatively different SFS than the
expansion alone -- the Transformer captures these interactions.
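
A quick sketch of how two such contrasting histories are fed to the encoder,
using the log-space inputs described in the admonition below. All values are
illustrative, and ``1.0`` stands in for the present-day epoch start so that its
logarithm stays finite:

.. code-block:: python

   import torch

   encoder = PiecewiseDemographyEncoder(d_model=128)

   # Looking back in time: present N_e = 1e5, a bottleneck epoch at 1e3,
   # ancestral 1e4 (forward in time: a bottleneck followed by an expansion).
   sizes_a = torch.log(torch.tensor([[1e5, 1e3, 1e4]]))
   times_a = torch.log(torch.tensor([[1.0, 1e3, 5e3]]))   # epoch starts, generations

   # Expansion alone: the middle epoch keeps the ancestral size.
   sizes_b = torch.log(torch.tensor([[1e5, 1e4, 1e4]]))
   times_b = torch.log(torch.tensor([[1.0, 1e3, 5e3]]))

   z_a = encoder(sizes_a, times_a)   # (1, 128)
   z_b = encoder(sizes_b, times_b)   # (1, 128)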

.. admonition:: Why log-space inputs

   Population sizes span orders of magnitude (:math:`10^2` to :math:`10^6`) and
   times span from tens to millions of generations. Working in log-space
   normalizes these ranges and makes the optimization landscape smoother.
   The encoder receives :math:`\log N_e` and :math:`\log t`, not the raw values.

Continuous Demography: Neural ODE
------------------------------------

For continuous :math:`N_e(t)` that cannot be represented as piecewise-constant,
the encoder uses a neural ODE. The idea is to parameterize the *dynamics* of
population size rather than its values at discrete change points:

.. math::

   \frac{d\mathbf{h}}{dt} = f_\theta(\mathbf{h}, t), \qquad
   \mathbf{z}_\Theta = \mathbf{h}(T)

where :math:`f_\theta` is a small MLP and :math:`\mathbf{h}(t) \in \mathbb{R}^d`
is a latent state that evolves continuously from the present (:math:`t = 0`) to
the deepest time of interest (:math:`t = T`). The initial condition
:math:`\mathbf{h}(0)` encodes the present-day population size.

.. code-block:: python

   import torch
   import torch.nn as nn
   from torchdiffeq import odeint

   class ContinuousDemographyEncoder(nn.Module):
       def __init__(self, d_model=128):
           super().__init__()
           self.d_model = d_model
           # Maps the present-day (first) parameter to the initial latent state h(0).
           self.initial = nn.Linear(1, d_model)
           # f_theta: latent dynamics, conditioned on the current time.
           self.dynamics = nn.Sequential(
               nn.Linear(d_model + 1, 256), nn.GELU(),
               nn.Linear(256, d_model))

       def forward(self, ne_func_params, t_eval):
           """
           ne_func_params: parameters defining N_e(t) (e.g., spline knots);
               the first column encodes the present-day population size
           t_eval: (T,) — time points for ODE integration
           """
           h0 = self.initial(ne_func_params[:, :1])

           def odefunc(t, h):
               # Broadcast the scalar solver time across the batch, append to h.
               t_input = t.expand(h.shape[0], 1)
               return self.dynamics(torch.cat([h, t_input], dim=-1))

           # Integrate from the present to the deepest time; the final latent
           # state is the demographic embedding z_Θ.
           trajectory = odeint(odefunc, h0, t_eval, method='dopri5')
           return trajectory[-1]

This handles arbitrarily complex demographic trajectories -- exponential growth,
oscillations, gradual bottlenecks -- without requiring the user to specify
the number of epochs.
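
A minimal usage sketch. The log-space knot values and the time grid rescaled to
:math:`[0, 1]` are assumptions chosen for illustration, not fixed conventions of
the encoder:

.. code-block:: python

   import torch

   encoder = ContinuousDemographyEncoder(d_model=128)

   # First column sets h(0) from the present-day size; the remaining columns
   # stand in for whatever parameterizes N_e(t) (e.g., spline knots).
   ne_func_params = torch.log(torch.tensor([[20_000.0, 5_000.0, 50_000.0]]))

   # Time rescaled to [0, 1]: 0 = present, 1 = deepest time of interest.
   t_eval = torch.linspace(0.0, 1.0, steps=32)

   z = encoder(ne_func_params, t_eval)   # (1, 128)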

Multi-Population Models: GNN
-------------------------------

For :math:`k` populations with a known topology (splits, merges, migration
edges), the demographic history forms a directed graph. Each node is a population
at a given time, and edges represent lineage flow (descent, migration). A graph
neural network (GNN) encodes this structure:

.. code-block:: python

   import torch
   import torch.nn as nn

   class PopulationTreeEncoder(nn.Module):
       def __init__(self, d_model=128, n_layers=3):
           super().__init__()
           self.node_embed = nn.Linear(3, d_model)
           self.message_layers = nn.ModuleList([
               nn.Sequential(
                   nn.Linear(2 * d_model + 1, d_model), nn.GELU(),
                   nn.Linear(d_model, d_model))
               for _ in range(n_layers)])
           self.update_layers = nn.ModuleList([
               nn.Sequential(
                   nn.Linear(2 * d_model, d_model), nn.GELU(),
                   nn.Linear(d_model, d_model))
               for _ in range(n_layers)])
           self.norms = nn.ModuleList([
               nn.LayerNorm(d_model) for _ in range(n_layers)])
           self.readout = nn.Sequential(
               nn.Linear(d_model, d_model), nn.GELU())

       def forward(self, node_features, edge_index, edge_attr):
           """
           node_features: (n_nodes, 3) — [log_size, log_time, pop_id]
           edge_index: (2, n_edges) — source and target indices
           edge_attr: (n_edges, 1) — migration rates (0 for descent edges)
           """
           h = self.node_embed(node_features)

           for msg_fn, upd_fn, norm in zip(
                   self.message_layers, self.update_layers, self.norms):
               src, dst = edge_index
               # Message per edge from the two endpoint embeddings and the
               # edge attribute (migration rate, or 0 for descent edges).
               messages = msg_fn(torch.cat([
                   h[src], h[dst], edge_attr], dim=-1))
               # Sum incoming messages at each target node.
               agg = torch.zeros_like(h)
               agg.index_add_(0, dst, messages)
               # Residual node update with layer norm.
               h = norm(h + upd_fn(torch.cat([h, agg], dim=-1)))

           return self.readout(h.mean(dim=0, keepdim=True))

Each node represents a population at a given time, with features
:math:`[\log N_e, \log t, \text{pop\_id}]`. Edges represent descent (from
ancestral to descendant population) or migration (between contemporary
populations). The GNN propagates information along these edges, allowing the
embedding to capture the full topology.
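
The sketch below assembles these inputs for a toy two-population split with
symmetric migration. The node ordering, rates, and times are illustrative
assumptions, not requirements of the encoder:

.. code-block:: python

   import math
   import torch

   encoder = PopulationTreeEncoder(d_model=128)

   # Node 0: ancestral population at the split time; nodes 1-2: present-day
   # daughters (t = 1 generation used as a stand-in for the present).
   node_features = torch.tensor([
       [math.log(20_000.0), math.log(5_000.0), 0.0],
       [math.log(10_000.0), math.log(1.0),     1.0],
       [math.log(15_000.0), math.log(1.0),     2.0],
   ])

   # Descent edges ancestor -> daughters, plus symmetric migration between
   # the daughters; edge_attr holds the migration rate (0 on descent edges).
   edge_index = torch.tensor([[0, 0, 1, 2],
                              [1, 2, 2, 1]])
   edge_attr = torch.tensor([[0.0], [0.0], [1e-4], [1e-4]])

   z = encoder(node_features, edge_index, edge_attr)   # (1, 128)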


Module 2: SFS Predictor
==========================

The SFS Predictor maps the demographic embedding :math:`\mathbf{z}_\Theta` to
the expected SFS :math:`\hat{\mathbf{M}}(\Theta)`. This is the core of Balance
Wheel -- the module that replaces the PDE/ODE solver.

Architecture
--------------

The predictor is a standard MLP with a softmax output to ensure positivity:

.. code-block:: python

   class SFSPredictor(nn.Module):
       def __init__(self, d_model=128, hidden=256, n_layers=4, max_n=100):
           super().__init__()
           self.max_n = max_n
           layers = [nn.Linear(d_model + 1, hidden), nn.GELU()]
           for _ in range(n_layers - 2):
               layers.extend([nn.Linear(hidden, hidden), nn.GELU()])
           layers.append(nn.Linear(hidden, max_n))
           self.mlp = nn.Sequential(*layers)

       def forward(self, z, n, theta_L):
           """
           z: (batch, d_model) — demographic embedding
           n: int — sample size
           theta_L: float — θ · L scaling factor
           """
           # Condition on the sample size, normalized to [0, 1] by max_n.
           n_input = torch.full((z.shape[0], 1), n / self.max_n,
                                device=z.device)
           raw = self.mlp(torch.cat([z, n_input], dim=-1))
           # Keep only the n - 1 polymorphic frequency classes for this n.
           raw = raw[:, :n - 1]
           # Softmax enforces positivity; scaling by θ·L fixes the expected total.
           sfs = torch.softmax(raw, dim=-1) * theta_L
           return sfs

The softmax ensures two constraints:

1. **Positivity**: :math:`\hat{M}_j > 0` for all :math:`j`. This is essential
   because the Poisson log-likelihood requires :math:`\ln M_j`, which is
   undefined for :math:`M_j \leq 0`.

2. **Normalization**: :math:`\sum_j \hat{M}_j = \theta L`. The total expected
   number of segregating sites is fixed by the mutation rate and sequence length.
   The softmax distributes this total across frequency bins.
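
The snippet below checks both constraints numerically on an untrained predictor;
the embedding and the :math:`\theta L` value are arbitrary stand-ins:

.. code-block:: python

   import torch

   predictor = SFSPredictor(d_model=128, max_n=100)
   z = torch.randn(1, 128)          # stand-in for an encoder embedding
   n, theta_L = 50, 8_000.0

   sfs = predictor(z, n, theta_L)   # shape (1, 49)
   assert (sfs > 0).all()                              # positivity
   assert torch.allclose(sfs.sum(dim=-1),
                         torch.tensor(theta_L))        # total equals θ·L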

.. admonition:: Why softmax and not ReLU

   A naive approach would be to use a ReLU output (ensuring positivity) and an
   unnormalized output. This fails for two reasons. First, ReLU can produce exact
   zeros, which cause :math:`\ln 0 = -\infty` in the Poisson likelihood. Second,
   without normalization, the total :math:`\sum \hat{M}_j` is unconstrained and
   the network must learn the correct total from data -- an unnecessary burden.
   Softmax handles both constraints simultaneously.

Sample-size conditioning
--------------------------

The predictor takes the sample size :math:`n` as an input (normalized by
:math:`n_{\max}`). This allows a single trained network to predict the SFS for
any sample size up to :math:`n_{\max}`, rather than training separate networks
for each :math:`n`. The sample size affects the SFS through the binomial sampling
step -- larger :math:`n` resolves finer frequency classes -- and the network
must learn this dependence.
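
For example, a single embedding can be decoded at several sample sizes, and the
number of predicted frequency classes tracks :math:`n`. A sketch with a random
stand-in embedding:

.. code-block:: python

   import torch

   predictor = SFSPredictor(d_model=128, max_n=100)
   z = torch.randn(1, 128)          # stand-in embedding from the encoder
   theta_L = 8_000.0

   for n in (10, 20, 100):
       sfs = predictor(z, n, theta_L)
       print(n, sfs.shape)          # (1, 9), (1, 19), (1, 99)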


Module 3: Poisson Likelihood
===============================

Module 3 contains no learnable parameters. It is pure mathematics: the exact
same Poisson log-likelihood that :ref:`dadi <dadi_timepiece>` and
:ref:`moments <moments_timepiece>` optimize.

.. math::

   \ell(\Theta) = \sum_{j=1}^{n-1} \left[
   D_j \ln \hat{M}_j(\Theta) - \hat{M}_j(\Theta) - \ln(D_j!)
   \right]

where :math:`D_j` is the observed SFS count and :math:`\hat{M}_j(\Theta)` is
the neural prediction. The last term is constant and can be dropped during
optimization.

.. code-block:: python

   def poisson_log_likelihood(observed_sfs, expected_sfs):
       """Exact Poisson log-likelihood — identical to dadi/moments.

       observed_sfs: (n-1,) integer tensor — observed SFS counts
       expected_sfs: (n-1,) positive tensor — predicted expected SFS
       """
       # Clamp guards against numerical underflow; the softmax output is positive.
       M = expected_sfs.clamp(min=1e-10)
       ll = observed_sfs * torch.log(M) - M - torch.lgamma(observed_sfs + 1)
       return ll.sum()

The critical property of this module is what it does *not* do: it does not
approximate the likelihood. The Poisson model is exact (under the PRF
assumptions). The only approximation in Balance Wheel is in Module 2 -- the
neural SFS prediction :math:`\hat{\mathbf{M}}(\Theta) \approx \mathbf{M}(\Theta)`.
The likelihood evaluation itself is lossless.

Gradient flow
---------------

The gradient of the log-likelihood with respect to :math:`\Theta` flows through
all three modules via backpropagation:

.. math::

   \nabla_\Theta \,\ell = \sum_{j=1}^{n-1} \left(\frac{D_j}{\hat{M}_j} - 1\right)
   \cdot \nabla_\Theta \hat{M}_j

The term :math:`\nabla_\Theta \hat{M}_j` is the Jacobian of the neural SFS
prediction with respect to the demographic parameters. PyTorch computes this
automatically via its autograd engine. No finite differences, no adjoint ODE
solves, no numerical Jacobians.
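
A minimal sketch of this gradient computation, chaining the three modules
defined above. The parameter values and the observed SFS are stand-ins:

.. code-block:: python

   import torch

   encoder = PiecewiseDemographyEncoder(d_model=128)
   predictor = SFSPredictor(d_model=128, max_n=100)
   encoder.eval()   # disable dropout so the likelihood is deterministic

   # Θ in log-space, marked as requiring gradients.
   log_sizes = torch.log(torch.tensor([[20_000.0, 2_000.0, 10_000.0]]))
   log_times = torch.log(torch.tensor([[1.0, 500.0, 3_000.0]]))
   log_sizes.requires_grad_(True)
   log_times.requires_grad_(True)

   observed_sfs = torch.randint(0, 500, (49,)).float()   # stand-in data, n = 50
   theta_L = observed_sfs.sum()

   z = encoder(log_sizes, log_times)                      # Module 1
   expected_sfs = predictor(z, n=50, theta_L=theta_L)     # Module 2
   ll = poisson_log_likelihood(observed_sfs, expected_sfs.squeeze(0))  # Module 3

   ll.backward()            # ∇_Θ ℓ via backprop, exactly the expression above
   print(log_sizes.grad)    # (1, 3): ∂ℓ/∂ log N_e per epoch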

.. admonition:: Gradient quality comparison

   - **dadi**: finite differences. :math:`\nabla_{\Theta_i} \ell \approx
     [\ell(\Theta + \epsilon \mathbf{e}_i) - \ell(\Theta - \epsilon \mathbf{e}_i)]
     / (2\epsilon)`. Requires :math:`2|\Theta|` forward solves. Sensitive to
     :math:`\epsilon` choice. Can be inaccurate for nearly flat likelihoods.

   - **moments**: AD through ODE solver. Exact in principle, but can suffer from
     numerical instability for stiff systems (large population size changes) and
     high memory cost (storing intermediate ODE states for backprop).

   - **Balance Wheel**: standard backprop through a small MLP. Exact, fast,
     numerically stable. The MLP has no stiff dynamics, no PDE to discretize,
     no frequency grid to resolve.


Putting It All Together
=========================

The complete Balance Wheel model chains the three modules:

.. code-block:: python

   class BalanceWheel(nn.Module):
       def __init__(self, d_model=128, n_heads=4, n_layers_enc=2,
                    n_layers_pred=4, hidden=256, max_n=100):
           super().__init__()
           self.encoder = PiecewiseDemographyEncoder(
               d_model=d_model, n_heads=n_heads, n_layers=n_layers_enc)
           self.predictor = SFSPredictor(
               d_model=d_model, hidden=hidden,
               n_layers=n_layers_pred, max_n=max_n)

       def forward(self, log_sizes, log_times, n, theta_L):
           z = self.encoder(log_sizes, log_times)
           expected_sfs = self.predictor(z, n, theta_L)
           return expected_sfs

       def log_likelihood(self, log_sizes, log_times, observed_sfs, theta_L):
           n = observed_sfs.shape[-1] + 1
           expected_sfs = self.forward(log_sizes, log_times, n, theta_L)
           return poisson_log_likelihood(observed_sfs, expected_sfs)
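
Once trained, the assembled model can be used for maximum-likelihood fitting by
freezing the network weights and ascending the log-likelihood in :math:`\Theta`.
The sketch below uses stand-in data, an untrained network, and an arbitrary
optimizer and learning rate purely to show the mechanics:

.. code-block:: python

   import torch

   model = BalanceWheel(max_n=100)
   model.eval()                            # disable dropout in the encoder
   for p in model.parameters():            # freeze the trained network weights
       p.requires_grad_(False)

   observed_sfs = torch.randint(0, 500, (49,)).float()   # stand-in data, n = 50
   theta_L = observed_sfs.sum()

   # Θ: three epochs. For simplicity only the sizes are optimized here; the
   # epoch start times are held fixed (all values are hypothetical).
   log_sizes = torch.log(torch.full((1, 3), 10_000.0)).requires_grad_(True)
   log_times = torch.log(torch.tensor([[1.0, 500.0, 3_000.0]]))

   opt = torch.optim.Adam([log_sizes], lr=0.05)
   for step in range(200):
       opt.zero_grad()
       loss = -model.log_likelihood(log_sizes, log_times, observed_sfs, theta_L)
       loss.backward()
       opt.step()

   print(log_sizes.detach().exp())   # fitted per-epoch population sizes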

.. list-table:: Computational complexity per module
   :header-rows: 1
   :widths: 25 25 50

   * - Module
     - Complexity
     - Bottleneck
   * - Demography Encoder
     - :math:`O(K^2 d)` (Transformer) or :math:`O(S d)` (Neural ODE, with :math:`S` solver steps)
     - Self-attention over :math:`K` epochs (typically :math:`K \leq 10`)
   * - SFS Predictor
     - :math:`O(H^2 L_{\text{layers}})` (MLP)
     - Matrix multiplications in hidden layers
   * - Poisson Likelihood
     - :math:`O(n)`
     - One pass over :math:`n - 1` SFS entries
   * - Backpropagation
     - Same as forward
     - Automatic; no additional algorithmic complexity

Total: :math:`O(K^2 d + H^2 L_{\text{layers}} + n)`, dominated by the MLP's
hidden layers. For typical parameters (:math:`K \leq 10`, :math:`d = 128`,
:math:`H = 256`, :math:`L_{\text{layers}} = 4`, :math:`n \leq 100`), the
forward pass takes ~0.1 ms on a modern GPU. Compare this to moments' ~10 ms
for a single population or dadi's ~100 ms -- a 100--1000× speedup.
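
To check the forward-pass cost on your own hardware, a rough timing harness
might look like the sketch below; the figures quoted above come from the
comparison in the text, not from this snippet:

.. code-block:: python

   import time
   import torch

   device = "cuda" if torch.cuda.is_available() else "cpu"
   model = BalanceWheel(max_n=100).to(device).eval()

   # Arbitrary stand-in inputs; the values do not affect the timing.
   log_sizes = torch.randn(1, 5, device=device)
   log_times = torch.randn(1, 5, device=device)

   with torch.no_grad():
       for _ in range(10):                                   # warm-up
           model(log_sizes, log_times, n=50, theta_L=8_000.0)
       if device == "cuda":
           torch.cuda.synchronize()
       start = time.perf_counter()
       for _ in range(1_000):
           model(log_sizes, log_times, n=50, theta_L=8_000.0)
       if device == "cuda":
           torch.cuda.synchronize()

   print(f"{(time.perf_counter() - start) / 1_000 * 1e3:.3f} ms per forward pass")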

.. admonition:: The architecture is deliberately simple

   Balance Wheel uses MLPs where Mainspring uses Transformers and Escapement uses
   variational posteriors with Gumbel-softmax. This simplicity is not a
   limitation -- it reflects the nature of the problem. The mapping
   :math:`\Theta \to \mathbf{M}(\Theta)` is a smooth, low-dimensional function.
   It does not require the expressive power of attention mechanisms or the
   statistical sophistication of variational inference. A well-trained MLP is
   the right tool for this job, just as the balance wheel -- a simple oscillating
   mass -- is the right regulator for a mechanical watch.
