.. _mainspring_design_principles:

==========================================
Design Principles -- One Per Timepiece
==========================================

   *A complication is not an ornament. Every additional mechanism must earn its place
   by solving a problem that simpler mechanisms cannot.*

Mainspring's architecture is not arbitrary. Every major design choice can be traced
back to a specific mathematical insight from a specific Timepiece. This chapter
catalogues ten such principles -- one from each Timepiece in the book -- and explains
how each is realized in the neural network.

The principles are not optional features. Remove any one and the network's performance
degrades measurably. Together they form the inductive bias that allows Mainspring to
learn from millions of simulations what would take billions without structure.


Principle 1: Sequential Markov Structure
==========================================

*From* :ref:`PSMC <psmc_timepiece>`

.. admonition:: Principle

   Along the genome, the coalescence time at position :math:`\ell + 1` depends on the
   coalescence time at position :math:`\ell` but not on positions :math:`1, \ldots,
   \ell - 1`. This is the **sequential Markov property** -- the foundation of every
   SMC-based method.

**What PSMC does.** PSMC models the sequence of coalescence times as a hidden Markov
chain along the genome. The transition matrix encodes the probability that a
recombination event changes the genealogy between adjacent genomic bins, and the
probability of re-coalescence at each possible time (see :ref:`psmc_continuous`).

**Realization in Mainspring.** The genomic encoder uses **sliding-window
attention**. Each genomic position attends only to positions within a window of
:math:`w` sites (default :math:`w = 512`); attention is masked so that position
:math:`\ell` cannot attend to any position :math:`\ell'` with
:math:`|\ell - \ell'| > w`. This is the neural analogue of the Markov property:
information propagates locally along the genome, respecting the fact that
recombination decorrelates genealogies over distance.

.. math::

   \text{Attention}(\ell, \ell') = \begin{cases}
   \dfrac{\exp\!\bigl(\mathbf{q}_\ell^\top \mathbf{k}_{\ell'} / \sqrt{d}\bigr)}
         {\sum_{\ell'' \,:\, |\ell - \ell''| \leq w}
          \exp\!\bigl(\mathbf{q}_\ell^\top \mathbf{k}_{\ell''} / \sqrt{d}\bigr)}
   & \text{if } |\ell - \ell'| \leq w \\[1ex]
   0 & \text{otherwise}
   \end{cases}

The window size :math:`w` is a hyperparameter analogous to the correlation length in
the SMC: it should be large enough to capture the typical distance between
recombination breakpoints. In practice, we set :math:`w` to cover roughly
:math:`2/\rho` sites, where :math:`\rho` is the per-site recombination rate.

.. code-block:: python

   import torch
   import torch.nn as nn

   class SlidingWindowAttention(nn.Module):
       def __init__(self, d_model, n_heads, window_size):
           super().__init__()
           self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
           self.window_size = window_size

       def forward(self, x):
           # x: (batch, L, d_model)
           L = x.size(1)
           # Boolean mask: True = attention disallowed. Each position may only
           # attend to neighbours within `window_size` sites on either side.
           idx = torch.arange(L, device=x.device)
           mask = (idx[None, :] - idx[:, None]).abs() > self.window_size
           out, _ = self.attn(x, x, x, attn_mask=mask)
           return out


Principle 2: Permutation Invariance over Samples
==================================================

*From* :ref:`SMC++ <smcpp_timepiece>`

.. admonition:: Principle

   Under the coalescent, the labels of the :math:`n` sampled haplotypes are
   **exchangeable**: the likelihood of a genotype matrix is invariant to
   reordering the rows. Any architecture that processes samples must respect this
   symmetry.

**What SMC++ does.** SMC++ handles multiple samples by designating one
**distinguished lineage** and treating the remaining :math:`n - 1` samples
symmetrically (see :ref:`smcpp_distinguished`). The undistinguished lineages
contribute through a permutation-invariant sufficient statistic: at each site, only
the allele frequency among the undistinguished samples matters, not which specific
samples carry the derived allele.

**Realization in Mainspring.** The genomic encoder applies a **Set Transformer**
(Lee et al. 2019) over the sample dimension. At each genomic position, the :math:`n`
sample embeddings are processed by induced set attention blocks (ISAB) that are
exactly equivariant to permutations of the input rows.

.. math::

   \mathbf{H}_\ell = \text{SetTransformer}\!\bigl(\mathbf{d}_{1,\ell},\;
   \mathbf{d}_{2,\ell},\; \ldots,\; \mathbf{d}_{n,\ell}\bigr)

where :math:`\mathbf{d}_{i,\ell}` is the embedding of sample :math:`i` at position
:math:`\ell`, and :math:`\mathbf{H}_\ell \in \mathbb{R}^{n \times d}` is the output
embedding matrix. Crucially, permuting the rows of the input permutes the rows of
the output in the same way -- the network cannot distinguish sample 1 from sample 7
unless the data itself distinguishes them.

.. code-block:: python

   class InducedSetAttention(nn.Module):
       """Simplified ISAB: attention to a small set of learned inducing points,
       then back out to the samples. Permutation-equivariant in the sample axis."""

       def __init__(self, d_model, n_heads, n_inducing):
           super().__init__()
           self.inducing = nn.Parameter(torch.randn(1, n_inducing, d_model))
           self.attn_to_inducing = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
           self.attn_from_inducing = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

       def forward(self, x):
           # x: (batch, n_samples, d_model)
           B = x.size(0)
           ind = self.inducing.expand(B, -1, -1)
           # Inducing points attend to the samples (a permutation-invariant summary)...
           h, _ = self.attn_to_inducing(ind, x, x)
           # ...and the samples attend back to the summary (permutation-equivariant output).
           out, _ = self.attn_from_inducing(x, h, h)
           return out
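
A quick sanity check makes the symmetry concrete: permuting the rows of the input
must permute the rows of the output identically. The snippet below is illustrative
only (the dimensions are arbitrary, and the tolerance absorbs floating-point
reordering effects):

.. code-block:: python

   torch.manual_seed(0)
   isab = InducedSetAttention(d_model=64, n_heads=4, n_inducing=16).eval()

   x = torch.randn(2, 10, 64)          # batch of 2, n = 10 samples
   perm = torch.randperm(10)

   out_then_perm = isab(x)[:, perm]    # permute after the forward pass
   perm_then_out = isab(x[:, perm])    # permute before the forward pass

   assert torch.allclose(out_then_perm, perm_then_out, atol=1e-5)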


Principle 3: The Full ARG as Sufficient Statistic
===================================================

*From* :ref:`ARGweaver <argweaver_timepiece>` *and* :ref:`SINGER <singer_timepiece>`

.. admonition:: Principle

   The ancestral recombination graph is a **sufficient statistic** for the data
   under the coalescent-with-recombination model. If you know the true ARG, no
   additional information in the genotype matrix can improve your inference of
   demographic parameters.

**What ARGweaver and SINGER do.** Both methods sample from the posterior distribution
over ARGs given the data. ARGweaver uses a threading-based MCMC scheme (see
:ref:`argweaver_mcmc`); SINGER uses Gibbs sampling with a Gaussian process prior on
branch lengths (see :ref:`branch_sampling`). Both produce full ARGs -- topology,
breakpoints, and node times -- as output.

**Realization in Mainspring.** Rather than predicting scalar summary statistics,
Mainspring's topology decoder and dating GNN jointly produce a **full ARG in tskit
format**. The network outputs:

- A sequence of local tree topologies (parent arrays)
- Breakpoint positions where trees change
- Node times for every internal node in every local tree

This is a much harder prediction target than a scalar, but it is the right one:
if the ARG is the sufficient statistic, then a network that predicts the ARG
captures all the information in the data. The demographic decoder then operates on
the predicted ARG, not on the raw genotype matrix, mirroring the factorization

.. math::

   p(N_e \mid \mathbf{D}) = \int p(N_e \mid \mathcal{A})\, p(\mathcal{A} \mid \mathbf{D})\, d\mathcal{A}

where :math:`\mathcal{A}` is the ARG and :math:`\mathbf{D}` is the data.
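
To make the output format concrete, the sketch below shows one way the predicted
pieces could be packed into a tskit tree sequence. The function name, argument
layout, and the single-global-node-numbering assumption are ours for illustration,
not Mainspring's actual API:

.. code-block:: python

   import tskit

   def arg_to_tree_sequence(breakpoints, parent_arrays, node_times, n_leaves):
       """Assemble predicted local trees into a tskit tree sequence (sketch).

       `breakpoints` has one more entry than `parent_arrays`; `parent_arrays[k][v]`
       is the parent of node `v` in the k-th local tree (-1 for a root), and all
       local trees share one global node numbering with times `node_times`.
       Predicted parent times must exceed child times for the tables to be valid.
       """
       tables = tskit.TableCollection(sequence_length=breakpoints[-1])
       for v, t in enumerate(node_times):
           flags = tskit.NODE_IS_SAMPLE if v < n_leaves else 0
           tables.nodes.add_row(flags=flags, time=float(t))
       for left, right, parents in zip(breakpoints[:-1], breakpoints[1:], parent_arrays):
           for child, parent in enumerate(parents):
               if parent >= 0:
                   tables.edges.add_row(left=left, right=right,
                                        parent=parent, child=child)
       tables.sort()
       return tables.tree_sequence()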


Principle 4: Message Passing for Dating
=========================================

*From* :ref:`tsdate <tsdate_timepiece>`

.. admonition:: Principle

   Given a tree topology and mutation counts on edges, node times can be inferred by
   **message passing** -- propagating information upward from leaves (inside pass) and
   downward from the root (outside pass). This is belief propagation on a tree-shaped
   graphical model.

**What tsdate does.** tsdate runs an inside-outside algorithm on each local tree in a
tree sequence. The inside pass aggregates mutation evidence from leaves toward the
root; the outside pass propagates coalescent-prior information from the root toward
the leaves. The result is a posterior distribution over node times that combines the
mutation data with the coalescent prior (see :ref:`tsdate_inside_outside`).

**Realization in Mainspring.** The dating GNN implements a **learned message-passing
scheme** on each local tree. Each node :math:`v` has a feature vector :math:`\mathbf{h}_v`
that is updated by aggregating messages from its parent and children:

.. math::

   \mathbf{h}_v^{(k+1)} = \text{Update}\!\Bigl(\mathbf{h}_v^{(k)},\;
   \text{Agg}\bigl(\{\mathbf{m}_{u \to v}^{(k)} : u \in \mathcal{N}(v)\}\bigr)\Bigr)

where :math:`\mathcal{N}(v)` includes the parent and children of :math:`v`, and the
message function :math:`\mathbf{m}_{u \to v}` depends on the edge features (mutation
count, genomic span) between :math:`u` and :math:`v`. After :math:`K` rounds of
message passing, the final node features are decoded into gamma-distributed time
posteriors.

The key difference from tsdate is that the message and update functions are
**learned**, not derived from the coalescent likelihood. This allows the GNN to
capture correlations that tsdate's factored approximation misses -- for example,
the constraint that a parent node must be older than all its children, which tsdate
enforces post-hoc but which the GNN can learn to respect automatically.
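
A minimal sketch of one such round is shown below, assuming a parent-array
representation of the local tree and per-edge feature vectors; the module name,
layer choices, and tensor layout are ours, not the exact Mainspring components:

.. code-block:: python

   import torch
   import torch.nn as nn

   class TreeMessagePassing(nn.Module):
       """One learned message-passing round on a local tree (illustrative sketch)."""

       def __init__(self, d_model, d_edge):
           super().__init__()
           self.msg = nn.Sequential(nn.Linear(2 * d_model + d_edge, d_model), nn.ReLU())
           self.update = nn.GRUCell(d_model, d_model)

       def forward(self, h, parent, edge_feats):
           # h: (n_nodes, d_model) node features; parent: (n_nodes,), -1 at the root
           # edge_feats: (n_nodes, d_edge), features of the edge above each node
           child = torch.nonzero(parent >= 0, as_tuple=True)[0]
           par = parent[child]
           # upward (child -> parent) and downward (parent -> child) messages,
           # both conditioned on the features of the connecting edge
           m_up = self.msg(torch.cat([h[child], h[par], edge_feats[child]], dim=-1))
           m_down = self.msg(torch.cat([h[par], h[child], edge_feats[child]], dim=-1))
           agg = torch.zeros_like(h)
           agg.index_add_(0, par, m_up)      # each parent sums messages from its children
           agg.index_add_(0, child, m_down)  # each child receives its parent's message
           return self.update(agg, h)        # GRU-style update of the node features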


Principle 5: Li & Stephens as Attention
=========================================

*From* :ref:`tsinfer <tsinfer_timepiece>` *and* :ref:`Threads <threads_timepiece>`

.. admonition:: Principle

   The Li & Stephens model says that each haplotype is an imperfect mosaic copy of
   other haplotypes. The pattern of "who copies whom" at each position defines the
   local tree topology. This copying relationship is formally equivalent to an
   **attention mechanism** where each haplotype attends to the haplotype it is
   copying.

**What tsinfer does.** tsinfer matches each sample haplotype against a panel of
ancestral haplotypes using the Li & Stephens HMM (see :ref:`tsinfer_copying_model`).
At each site, the Viterbi path identifies which ancestor the sample is copying. A
recombination event corresponds to a switch in the copying source -- a change in
attention.

**Realization in Mainspring.** The topology decoder uses **cross-attention between
haplotype embeddings**. At each genomic position, each sample's query vector attends
over all other samples' key vectors:

.. math::

   \alpha_{ij}^\ell = \frac{\exp\!\bigl(\mathbf{q}_{i}^{\ell \top} \mathbf{k}_{j}^{\ell}
   / \sqrt{d}\bigr)}{\sum_{j' \neq i} \exp\!\bigl(\mathbf{q}_{i}^{\ell \top}
   \mathbf{k}_{j'}^{\ell} / \sqrt{d}\bigr)}

The attention weight :math:`\alpha_{ij}^\ell` is the soft analogue of "sample :math:`i`
copies from sample :math:`j` at position :math:`\ell`." To produce discrete parent
assignments (required for a tree topology), we apply **Gumbel-softmax** during
training, which allows gradients to flow through the discrete choice:

.. math::

   \hat{\alpha}_{ij}^\ell = \frac{\exp\!\bigl((\log \alpha_{ij}^\ell + g_{ij}) / \tau\bigr)}
   {\sum_{j'} \exp\!\bigl((\log \alpha_{ij'}^\ell + g_{ij'}) / \tau\bigr)}

where :math:`g_{ij} \sim \text{Gumbel}(0, 1)` and :math:`\tau` is the temperature,
annealed toward zero during training.
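
In code, the relaxation can lean on PyTorch's built-in Gumbel-softmax. The sketch
below is illustrative: the diagonal mask and the ``hard=True`` straight-through
choice are assumptions about the setup, not documented Mainspring behaviour:

.. code-block:: python

   import torch
   import torch.nn.functional as F

   def sample_copying_sources(logits, tau=1.0):
       """Relax the discrete "who copies whom" choice at one genomic position.

       logits: (n_samples, n_samples) attention logits; the diagonal is masked so
       a haplotype cannot copy from itself. Returns one (soft) one-hot row per sample.
       """
       eye = torch.eye(logits.size(0), dtype=torch.bool, device=logits.device)
       masked = logits.masked_fill(eye, float("-inf"))
       # hard=True gives one-hot samples in the forward pass while gradients flow
       # through the soft relaxation (straight-through estimator)
       return F.gumbel_softmax(masked, tau=tau, hard=True, dim=-1)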


Principle 6: SFS as Physics-Informed Regularizer
===================================================

*From* :ref:`dadi <dadi_timepiece>`, :ref:`moments <moments_timepiece>`, *and*
:ref:`momi2 <momi2_timepiece>`

.. admonition:: Principle

   The **site frequency spectrum** (SFS) is a deterministic function of the ARG:
   :math:`\text{SFS}[k] = \sum_e b(e) \cdot \mathbf{1}[\text{descendants}(e) = k]`,
   where :math:`b(e)` is the branch length of edge :math:`e` and :math:`k` is the
   number of descendant leaves. This relationship is differentiable and can be used
   as a **physics-informed loss** -- a hard constraint from population genetics theory
   that the network must satisfy.

**What dadi/moments/momi2 do.** These methods compute the expected SFS under a
demographic model using diffusion equations (:ref:`dadi <dadi_diffusion_equation>`),
moment equations (:ref:`moments <moment_equations>`), or Moran-model tensor algebra
(:ref:`momi2 <tensor_machinery>`). They then fit the model to the observed SFS.

**Realization in Mainspring.** The predicted ARG (topology + node times) implies a
predicted SFS. We compute it differentiably:

.. code-block:: python

   def predicted_sfs(parent_array, node_times, n_leaves):
       """Compute the branch-length SFS from a predicted local tree (differentiable).

       `parent_array[v]` is the parent of node `v` (-1 for the root);
       `count_descendants` (defined elsewhere) returns the number of sample
       leaves below a node.
       """
       sfs = torch.zeros(n_leaves + 1)
       for child, parent in enumerate(parent_array):
           if parent < 0:  # skip the root: it has no branch above it
               continue
           branch_length = node_times[parent] - node_times[child]
           n_descendants = count_descendants(child, parent_array, n_leaves)
           # gradients flow through branch_length back to the predicted node times
           sfs[n_descendants] = sfs[n_descendants] + branch_length
       return sfs[1:-1]  # exclude the monomorphic classes 0 and n

The SFS loss is then the :math:`\chi^2` distance (or Poisson log-likelihood) between
the predicted SFS and the true SFS computed from the genotype matrix. This loss does
not require knowing the true ARG -- only the observed allele frequencies -- and acts
as a physics-informed regularizer that guides the network toward ARGs consistent with
the observed diversity pattern.
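
A Poisson-style version of the loss might look like the following sketch, where the
predicted branch-length SFS is rescaled to the observed number of segregating sites
so that only the shape of the spectrum is penalized; this rescaling is our
simplification, not necessarily how Mainspring weights the term:

.. code-block:: python

   import torch

   def sfs_loss(pred_branch_sfs, obs_sfs, eps=1e-8):
       """Poisson negative log-likelihood between predicted and observed SFS (sketch)."""
       # convert branch lengths into expected counts with the same total as observed
       rate = pred_branch_sfs / (pred_branch_sfs.sum() + eps) * obs_sfs.sum()
       # constant log(k!) terms are dropped
       return (rate - obs_sfs * torch.log(rate + eps)).sum()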


Principle 7: Gamma-Distributed Branch-Length Posteriors
=========================================================

*From* :ref:`Gamma-SMC <gamma_smc_timepiece>`

.. admonition:: Principle

   Coalescence times are positive, right-skewed random variables. The **gamma
   distribution** is a natural parametric family for their posteriors: it is
   supported on :math:`(0, \infty)`, can represent both peaked and diffuse
   uncertainty, and is closed under the multiplicative updates that arise in
   sequential Bayesian inference.

**What Gamma-SMC does.** Gamma-SMC maintains a gamma-distributed belief over the
coalescence time at each position, updating it as new sites are processed (see
:ref:`gamma_smc_gamma_approximation`). The shape parameter :math:`\alpha` tracks the
"evidence count" and the rate parameter :math:`\beta` tracks the "evidence mass."

**Realization in Mainspring.** The dating GNN's output heads predict **gamma
parameters** :math:`(\alpha_v, \beta_v)` for each node :math:`v`:

.. math::

   \alpha_v = \text{softplus}(\mathbf{w}_\alpha^\top \mathbf{h}_v + b_\alpha) + 1

.. math::

   \beta_v = \exp\!\bigl(\mathbf{w}_\beta^\top \mathbf{h}_v + b_\beta\bigr)

The :math:`+1` in the shape parameter ensures :math:`\alpha_v > 1`, so the gamma
density has a mode at :math:`(\alpha_v - 1)/\beta_v > 0`, preventing degenerate
distributions concentrated at zero. The loss on node times is the negative
log-likelihood under the predicted gamma:

.. math::

   \mathcal{L}_{\text{time}} = -\sum_v \log \text{Gamma}(t_v^* \mid \alpha_v, \beta_v)

where :math:`t_v^*` is the true node time from the simulation.
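
With the gamma parameters in hand, the time loss is a short expression via
``torch.distributions``; the clamp below is our guard against sample nodes dated at
exactly zero, not a documented Mainspring detail:

.. code-block:: python

   from torch.distributions import Gamma

   def time_loss(alpha, beta, true_times, eps=1e-6):
       """Negative log-likelihood of true node times under the predicted gammas."""
       dist = Gamma(concentration=alpha, rate=beta)
       return -dist.log_prob(true_times.clamp_min(eps)).sum()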


Principle 8: Randomized Discretization
=========================================

*From* :ref:`phlash <phlash_timepiece>`

.. admonition:: Principle

   Fixed discretization of time creates **aliasing artifacts**: the inferred
   :math:`N_e(t)` is biased toward the grid. **Randomizing the discretization**
   during training -- using different time grids for different training examples --
   forces the network to learn a representation that is robust to the choice of grid.

**What phlash does.** phlash uses random discretization of the coalescent time axis
to avoid the aliasing artifacts that plague PSMC and SMC++ (see
:ref:`phlash_random_discretization`). Each gradient step uses a different random
grid, so the posterior estimate averages over grid artifacts.

**Realization in Mainspring.** We apply randomized discretization in two places:

1. **Random Fourier features for positional encoding.** Instead of fixed sinusoidal
   positional embeddings, we use random Fourier features whose frequency distribution
   is drawn fresh for each training batch. This prevents the network from memorizing
   position-specific artifacts.

2. **Varied attention window sizes.** During training, the sliding-window size
   :math:`w` is sampled uniformly from :math:`[w_{\min}, w_{\max}]` for each batch
   (a per-batch draw is sketched after the positional-encoding code below). This
   forces the encoder to be robust to different effective correlation lengths,
   just as phlash's random discretization forces robustness to different time grids.

.. code-block:: python

   class RandomFourierPositionalEncoding(nn.Module):
       def __init__(self, d_model, sigma=10.0):
           super().__init__()
           self.d_model = d_model
           self.sigma = sigma
           # fixed frequencies, reused at evaluation time
           self.register_buffer("eval_freqs", torch.randn(d_model // 2) * sigma)

       def forward(self, positions):
           # positions: (batch, L). During training, resample the frequencies for
           # every batch so the network cannot memorize grid-specific artifacts;
           # at evaluation time, reuse the fixed buffer for deterministic output.
           if self.training:
               freqs = torch.randn(self.d_model // 2, device=positions.device) * self.sigma
           else:
               freqs = self.eval_freqs
           proj = positions.unsqueeze(-1) * freqs.unsqueeze(0)
           return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
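
For the second mechanism, the per-batch window draw is a single line in the training
loop; the bounds and the ``sample_window_size`` helper below are illustrative, not
Mainspring defaults:

.. code-block:: python

   import random

   def sample_window_size(w_min=128, w_max=1024):
       """Draw a fresh sliding-window size for the next training batch."""
       return random.randint(w_min, w_max)

   # inside the training loop, before each batch (illustrative):
   # encoder.window_size = sample_window_size()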


Principle 9: Per-Segment Sufficient Statistics as Features
============================================================

*From* :ref:`Threads <threads_timepiece>` *(dating)*

.. admonition:: Principle

   For a genomic segment of length :math:`\ell` bp with :math:`m` mutations, the
   natural estimator of the total branch length is :math:`m / \mu\ell`, and the
   estimator of TMRCA for a pair is :math:`(m + 1) / (\rho + \mu)`. These
   **per-segment sufficient statistics** capture the essential information in each
   genomic region and should be provided as input features to any dating algorithm.

**What Threads does.** Threads dates each edge in a tree sequence using the mutation
count on the edge and the genomic span it covers (see :ref:`dating_threads`). The
estimator :math:`\hat{t} = (m + 1) / \bigl((\rho + \mu)\ell\bigr)` is simple but
effective, combining recombination and mutation evidence in a single number.

**Realization in Mainspring.** The dating GNN receives **per-edge features** that
include:

.. list-table::
   :header-rows: 1
   :widths: 30 70

   * - Feature
     - Description
   * - :math:`m_e`
     - Mutation count on edge :math:`e`
   * - :math:`s_e`
     - Genomic span of edge :math:`e` (in base pairs)
   * - :math:`\hat{t}_e = (m_e + 1) / (\hat{\rho} s_e + \hat{\mu} s_e)`
     - Threads-style natural time estimator
   * - :math:`\log m_e`, :math:`\log s_e`
     - Log-transformed versions (for numerical stability)
   * - :math:`n_e`
     - Number of descendant leaves below edge :math:`e`

These features are concatenated and passed through a linear projection to produce
edge embeddings. The natural estimator :math:`\hat{t}_e` is not the final answer --
it is the initial condition for the GNN's message-passing updates, analogous to how
Threads' initial estimates are refined by iterative dating algorithms.
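
The sketch below computes these features from a tskit tree sequence. The function
and variable names are ours, and mutation counts are attributed to the edge above
the carrier node in the local tree at each site, which is a simplification of
whatever bookkeeping Mainspring actually performs:

.. code-block:: python

   import numpy as np
   import tskit

   def edge_feature_matrix(ts, mu_hat, rho_hat):
       """Per-edge features [m_e, s_e, t_hat_e, log(1 + m_e), log s_e, n_e] (sketch)."""
       m = np.zeros(ts.num_edges)
       # attribute each mutation to the edge above its node in the local tree at its site
       for site in ts.sites():
           tree = ts.at(site.position)
           for mut in site.mutations:
               edge_id = tree.edge(mut.node)
               if edge_id != tskit.NULL:
                   m[edge_id] += 1

       feats = np.zeros((ts.num_edges, 6))
       for e in ts.edges():
           span = e.right - e.left
           n_desc = ts.at(e.left).num_samples(e.child)  # descendant leaves below the edge
           t_hat = (m[e.id] + 1.0) / ((rho_hat + mu_hat) * span)
           feats[e.id] = [m[e.id], span, t_hat, np.log1p(m[e.id]), np.log(span), n_desc]
       return feats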


Principle 10: Log-Space Computation and Softplus Positivity
=============================================================

*From* :ref:`PSMC <psmc_timepiece>` *(decoding and scaling)*

.. admonition:: Principle

   Population-genetic quantities span many orders of magnitude: coalescence times
   range from :math:`10^1` to :math:`10^6` generations, population sizes from
   :math:`10^2` to :math:`10^7`. Working in **log-space** prevents numerical
   underflow/overflow and ensures that the network treats a factor-of-two change in
   :math:`N_e` the same whether it occurs at :math:`N_e = 1{,}000` or
   :math:`N_e = 100{,}000`.

**What PSMC does.** PSMC's forward-backward algorithm operates in log-space to
prevent underflow over long genomic sequences (see :ref:`psmc_decoding`). The
time discretization is log-spaced, giving finer resolution in the recent past where
the data is most informative. The output :math:`\lambda(t)` is plotted on a
log-log scale.

**Realization in Mainspring.** We enforce log-scale / positive-quantity conventions
throughout:

- **Node times** are predicted in log-space: the GNN outputs
  :math:`\log \hat{t}_v`, and we exponentiate to get :math:`\hat{t}_v = e^{\log \hat{t}_v}`.
  This ensures positivity and gives the network equal relative precision across
  all time scales.

- **Population sizes** are parameterized through softplus:
  :math:`N_e(t) = \text{softplus}(z(t)) = \log(1 + e^{z(t)})`, where :math:`z(t)`
  is the unconstrained output of the demographic decoder. Softplus is smoother than
  ReLU at zero and better behaved than exponentiation for large values.

- **Branch lengths** in the gamma output heads use log-parameterized rate:
  :math:`\beta_v = \exp(\tilde{\beta}_v)`, ensuring :math:`\beta_v > 0` without
  clamping.

.. code-block:: python

   import torch
   import torch.nn.functional as F

   def decode_times(log_times_raw):
       """Convert raw network output (log-times) to positive times."""
       return torch.exp(log_times_raw)

   def decode_population_sizes(z):
       """Convert unconstrained output to positive N_e via softplus."""
       return F.softplus(z)

   def decode_gamma_params(alpha_raw, beta_raw):
       """Convert raw outputs to valid gamma parameters."""
       alpha = F.softplus(alpha_raw) + 1.0  # ensure alpha > 1 (mode away from zero)
       beta = torch.exp(beta_raw)           # ensure beta > 0
       return alpha, beta


Summary
========

.. list-table:: Ten design principles and their origins
   :header-rows: 1
   :widths: 5 20 35 40

   * - #
     - Source Timepiece
     - Principle
     - Neural realization
   * - 1
     - :ref:`PSMC <psmc_timepiece>`
     - Sequential Markov structure
      - Sliding-window attention
   * - 2
     - :ref:`SMC++ <smcpp_timepiece>`
     - Permutation invariance over samples
     - Set Transformer (ISAB)
   * - 3
     - :ref:`ARGweaver <argweaver_timepiece>` / :ref:`SINGER <singer_timepiece>`
     - Full ARG is sufficient statistic
     - Output full ARG in tskit format
   * - 4
     - :ref:`tsdate <tsdate_timepiece>`
     - Message passing for dating
     - GNN on local trees
   * - 5
     - :ref:`tsinfer <tsinfer_timepiece>` / :ref:`Threads <threads_timepiece>`
     - Li & Stephens = attention
     - Cross-attention + Gumbel-softmax
   * - 6
     - :ref:`dadi <dadi_timepiece>` / :ref:`moments <moments_timepiece>` /
       :ref:`momi2 <momi2_timepiece>`
     - SFS as auxiliary loss
     - Differentiable SFS computation
   * - 7
     - :ref:`Gamma-SMC <gamma_smc_timepiece>`
     - Gamma posteriors for branch lengths
     - Gamma output heads on GNN
   * - 8
     - :ref:`phlash <phlash_timepiece>`
     - Randomized discretization
     - Random Fourier features, varied window sizes
   * - 9
     - :ref:`Threads <threads_timepiece>` (dating)
     - Per-segment sufficient statistics
     - Edge features: :math:`m_e, s_e, \hat{t}_e`
   * - 10
     - :ref:`PSMC <psmc_timepiece>` (decoding)
     - Log-space, softplus positivity
     - Log-time prediction, softplus :math:`N_e`

These principles are not post-hoc rationalizations. They were the starting point of
the design process: we asked, for each Timepiece, "what is the one structural insight
that makes this method work?" and then found the neural analogue. The next chapter
assembles these principles into a complete architecture.
