Design Principles – One Per Timepiece

A complication is not an ornament. Every additional mechanism must earn its place by solving a problem that simpler mechanisms cannot.

Mainspring’s architecture is not arbitrary. Every major design choice can be traced back to a specific mathematical insight from a specific Timepiece. This chapter catalogues ten such principles – one from each Timepiece in the book – and explains how each is realized in the neural network.

The principles are not optional features. Remove any one and the network’s performance degrades measurably. Together they form the inductive bias that allows Mainspring to learn from millions of simulations what would take billions without structure.

Principle 1: Sequential Markov Structure

From PSMC

Principle

Along the genome, the coalescence time at position \(\ell + 1\) depends on the coalescence time at position \(\ell\); conditional on it, positions \(1, \ldots, \ell - 1\) carry no additional information. This is the sequential Markov property – the foundation of every SMC-based method.

What PSMC does. PSMC models the sequence of coalescence times as a hidden Markov chain along the genome. The transition matrix encodes the probability that a recombination event changes the genealogy between adjacent genomic bins, and the probability of re-coalescence at each possible time (see The Continuous-Time PSMC Model).

Realization in Mainspring. The genomic encoder uses sliding-window local attention. Each genomic position attends only to positions within a window of \(w\) sites (default \(w = 512\)): attention is masked so that position \(\ell\) can attend to position \(\ell'\) only if \(|\ell - \ell'| \leq w\). This is the neural analogue of the Markov property: information propagates locally along the genome, respecting the fact that recombination decorrelates genealogies over distance.

\[\begin{split}\text{Attention}(\ell, \ell') = \begin{cases} \dfrac{\exp\!\bigl(\mathbf{q}_\ell^\top \mathbf{k}_{\ell'} / \sqrt{d}\bigr)}{\sum_{\ell'' : |\ell - \ell''| \leq w} \exp\!\bigl(\mathbf{q}_\ell^\top \mathbf{k}_{\ell''} / \sqrt{d}\bigr)} & \text{if } |\ell - \ell'| \leq w \\[1ex] 0 & \text{otherwise} \end{cases}\end{split}\]

The window size \(w\) is a hyperparameter analogous to the correlation length in the SMC: it should be large enough to capture the typical distance between recombination breakpoints. In practice, we set \(w\) to cover roughly \(2/\rho\) sites, where \(\rho\) is the per-site recombination rate.

import torch
import torch.nn as nn

class SlidingWindowAttention(nn.Module):
    """Self-attention in which each position attends only to positions within `window_size`."""

    def __init__(self, d_model, n_heads, window_size):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.window_size = window_size

    def forward(self, x):
        # x: (batch, positions, d_model). Build a boolean mask where True means
        # "do not attend": positions farther apart than window_size are masked out.
        L = x.size(1)
        idx = torch.arange(L, device=x.device)
        mask = (idx[None, :] - idx[:, None]).abs() > self.window_size
        out, _ = self.attn(x, x, x, attn_mask=mask)
        return out
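
A usage sketch (the tensor shapes below are illustrative, not Mainspring defaults):

# Illustrative usage; in practice window_size would be set to roughly 2/rho sites.
layer = SlidingWindowAttention(d_model=128, n_heads=4, window_size=512)
x = torch.randn(8, 2048, 128)   # (batch, genomic positions, d_model)
out = layer(x)                  # (8, 2048, 128); each position sees at most 512 neighbours to either side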

Principle 2: Permutation Invariance over Samples

From SMC++

Principle

Under the coalescent, the labels of the \(n\) sampled haplotypes are exchangeable: the likelihood of a genotype matrix is invariant to reordering the rows. Any architecture that processes samples must respect this symmetry.

What SMC++ does. SMC++ handles multiple samples by designating one distinguished lineage and treating the remaining \(n - 1\) samples symmetrically (see The Distinguished Lineage). The undistinguished lineages contribute through a permutation-invariant sufficient statistic: at each site, only the allele frequency among the undistinguished samples matters, not which specific samples carry the derived allele.

Realization in Mainspring. The genomic encoder applies a Set Transformer (Lee et al. 2019) over the sample dimension. At each genomic position, the \(n\) sample embeddings are processed by induced set attention blocks (ISAB) that are exactly equivariant to permutations of the input rows.

\[\mathbf{H}_\ell = \text{SetTransformer}\!\bigl(\mathbf{d}_{1,\ell},\; \mathbf{d}_{2,\ell},\; \ldots,\; \mathbf{d}_{n,\ell}\bigr)\]

where \(\mathbf{d}_{i,\ell}\) is the embedding of sample \(i\) at position \(\ell\), and \(\mathbf{H}_\ell \in \mathbb{R}^{n \times d}\) is the output embedding matrix. Crucially, permuting the rows of the input permutes the rows of the output in the same way – the network cannot distinguish sample 1 from sample 7 unless the data itself distinguishes them.

class InducedSetAttention(nn.Module):
    """Induced set attention block (ISAB): permutation-equivariant over the sample axis."""

    def __init__(self, d_model, n_heads, n_inducing):
        super().__init__()
        # Learned inducing points act as a low-rank bottleneck over the sample set.
        self.inducing = nn.Parameter(torch.randn(1, n_inducing, d_model))
        self.attn_to_inducing = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.attn_from_inducing = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x):
        # x: (batch, samples, d_model)
        B = x.size(0)
        ind = self.inducing.expand(B, -1, -1)
        # Inducing points summarize the sample set (permutation-invariant in x) ...
        h, _ = self.attn_to_inducing(ind, x, x)
        # ... and each sample reads back from the summary (permutation-equivariant).
        out, _ = self.attn_from_inducing(x, h, h)
        return out
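
A quick check of the equivariance claim, with illustrative shapes (nn.MultiheadAttention applies no dropout by default, so outputs are deterministic):

# Permuting the sample rows should permute the output rows identically.
block = InducedSetAttention(d_model=64, n_heads=4, n_inducing=16)
x = torch.randn(2, 10, 64)                 # (batch, samples, d_model)
perm = torch.randperm(10)
with torch.no_grad():
    out = block(x)
    out_perm = block(x[:, perm, :])
assert torch.allclose(out[:, perm, :], out_perm, atol=1e-5)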

Principle 3: The Full ARG as Sufficient Statistic

From ARGweaver and SINGER

Principle

The ancestral recombination graph is a sufficient statistic for the data under the coalescent-with-recombination model. If you know the true ARG, no additional information in the genotype matrix can improve your inference of demographic parameters.

What ARGweaver and SINGER do. Both methods sample from the posterior distribution over ARGs given the data. ARGweaver uses a threading-based MCMC scheme (see MCMC Sampling); SINGER uses Gibbs sampling with a Gaussian process prior on branch lengths (see Branch Sampling). Both produce full ARGs – topology, breakpoints, and node times – as output.

Realization in Mainspring. Rather than predicting scalar summary statistics, Mainspring’s topology decoder and dating GNN jointly produce a full ARG in tskit format. The network outputs:

  • A sequence of local tree topologies (parent arrays)

  • Breakpoint positions where trees change

  • Node times for every internal node in every local tree

This is a much harder prediction target than a scalar, but it is the right one: if the ARG is the sufficient statistic, then a network that predicts the ARG captures all the information in the data. The demographic decoder then operates on the predicted ARG, not on the raw genotype matrix, mirroring the factorization

\[p(N_e \mid \mathbf{D}) = \int p(N_e \mid \mathcal{A})\, p(\mathcal{A} \mid \mathbf{D})\, d\mathcal{A}\]

where \(\mathcal{A}\) is the ARG and \(\mathbf{D}\) is the data.
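
As a concrete illustration of this output format, a single predicted local tree (one parent array plus node times) can be written out with tskit roughly as follows. The function name and argument conventions are ours, not Mainspring's API, and we assume predicted parent times exceed child times:

import tskit

def build_local_tree(parent, times, n_leaves, sequence_length):
    """Assemble one predicted local tree into a tskit tree sequence (illustrative sketch).

    parent[v] is the predicted parent of node v (-1 for the root); times[v] is the
    predicted age of node v, with sample nodes 0..n_leaves-1 at time 0.
    """
    tables = tskit.TableCollection(sequence_length=sequence_length)
    for v, t in enumerate(times):
        flags = tskit.NODE_IS_SAMPLE if v < n_leaves else 0
        tables.nodes.add_row(flags=flags, time=float(t))
    for child, p in enumerate(parent):
        if p >= 0:  # the root contributes no edge
            tables.edges.add_row(left=0.0, right=sequence_length,
                                 parent=int(p), child=int(child))
    tables.sort()
    return tables.tree_sequence()

Extending this to a full ARG means writing one set of edges per breakpoint interval, with left and right taken from the predicted breakpoint positions.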

Principle 4: Message Passing for Dating

From tsdate

Principle

Given a tree topology and mutation counts on edges, node times can be inferred by message passing – propagating information upward from leaves (inside pass) and downward from the root (outside pass). This is belief propagation on a tree-shaped graphical model.

What tsdate does. tsdate runs an inside-outside algorithm on each local tree in a tree sequence. The inside pass aggregates mutation evidence from leaves toward the root; the outside pass propagates coalescent-prior information from the root toward the leaves. The result is a posterior distribution over node times that combines data (mutations) with prior (coalescent) (see Inside-Outside Belief Propagation).

Realization in Mainspring. The dating GNN implements a learned message-passing scheme on each local tree. Each node \(v\) has a feature vector \(\mathbf{h}_v\) that is updated by aggregating messages from its parent and children:

\[\mathbf{h}_v^{(k+1)} = \text{Update}\!\Bigl(\mathbf{h}_v^{(k)},\; \text{Agg}\bigl(\{\mathbf{m}_{u \to v}^{(k)} : u \in \mathcal{N}(v)\}\bigr)\Bigr)\]

where \(\mathcal{N}(v)\) includes the parent and children of \(v\), and the message function \(\mathbf{m}_{u \to v}\) depends on the edge features (mutation count, genomic span) between \(u\) and \(v\). After \(K\) rounds of message passing, the final node features are decoded into gamma-distributed time posteriors.
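
A minimal sketch of one such round; the layer sizes and the GRU-style update are our choices for illustration, and the edge features are the mutation counts and genomic spans described above:

class TreeMessagePassing(nn.Module):
    """One round of learned message passing on a local tree (illustrative sketch)."""

    def __init__(self, d_node, d_edge):
        super().__init__()
        self.message = nn.Sequential(
            nn.Linear(2 * d_node + d_edge, d_node), nn.ReLU(),
            nn.Linear(d_node, d_node))
        self.update = nn.GRUCell(d_node, d_node)

    def forward(self, h, edge_index, edge_feat):
        # h: (n_nodes, d_node) node features.
        # edge_index: (2, n_edges) directed edges (u -> v); list both directions so that
        # each node receives messages from its parent and its children.
        # edge_feat: (n_edges, d_edge) mutation count, genomic span, etc.
        src, dst = edge_index
        m = self.message(torch.cat([h[src], h[dst], edge_feat], dim=-1))
        agg = torch.zeros_like(h).index_add_(0, dst, m)  # sum messages arriving at each node
        return self.update(agg, h)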

The key difference from tsdate is that the message and update functions are learned, not derived from the coalescent likelihood. This allows the GNN to capture correlations that tsdate’s factored approximation misses – for example, the constraint that a parent node must be older than all its children, which tsdate enforces post-hoc but which the GNN can learn to respect automatically.

Principle 5: Li & Stephens as Attention

From tsinfer and Threads

Principle

The Li & Stephens model says that each haplotype is an imperfect mosaic copy of other haplotypes. The pattern of “who copies whom” at each position defines the local tree topology. This copying relationship is formally equivalent to an attention mechanism where each haplotype attends to the haplotype it is copying.

What tsinfer does. tsinfer matches each sample haplotype against a panel of ancestral haplotypes using the Li & Stephens HMM (see Gear 2: The Copying Model). At each site, the Viterbi path identifies which ancestor the sample is copying. A recombination event corresponds to a switch in the copying source – a change in attention.

Realization in Mainspring. The topology decoder uses cross-attention between haplotype embeddings. At each genomic position, each sample’s query vector attends over all other samples’ key vectors:

\[\alpha_{ij}^\ell = \frac{\exp\!\bigl(\mathbf{q}_{i}^{\ell \top} \mathbf{k}_{j}^{\ell} / \sqrt{d}\bigr)}{\sum_{j' \neq i} \exp\!\bigl(\mathbf{q}_{i}^{\ell \top} \mathbf{k}_{j'}^{\ell} / \sqrt{d}\bigr)}\]

The attention weight \(\alpha_{ij}^\ell\) is the soft analogue of “sample \(i\) copies from sample \(j\) at position \(\ell\).” To produce discrete parent assignments (required for a tree topology), we apply Gumbel-softmax during training, which allows gradients to flow through the discrete choice:

\[\hat{\alpha}_{ij}^\ell = \frac{\exp\!\bigl((\log \alpha_{ij}^\ell + g_{ij}) / \tau\bigr)} {\sum_{j'} \exp\!\bigl((\log \alpha_{ij'}^\ell + g_{ij'}) / \tau\bigr)}\]

where \(g_{ij} \sim \text{Gumbel}(0, 1)\) and \(\tau\) is the temperature, annealed toward zero during training.
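
A hedged sketch of the discrete parent choice using torch.nn.functional.gumbel_softmax; the self-copy masking and the straight-through flag are simplifications of ours:

import torch.nn.functional as F

def sample_copying_sources(logits, tau, hard=True):
    """Draw a (near-)discrete copying source for each sample at one position.

    logits: (n_samples, n_samples) scores q_i . k_j / sqrt(d); the diagonal is masked
    because a sample cannot copy from itself.
    """
    eye = torch.eye(logits.size(0), dtype=torch.bool, device=logits.device)
    logits = logits.masked_fill(eye, float("-inf"))
    # Straight-through Gumbel-softmax: one-hot on the forward pass, soft gradients backward.
    weights = F.gumbel_softmax(logits, tau=tau, hard=hard, dim=-1)
    return weights.argmax(dim=-1), weights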

Principle 6: SFS as Physics-Informed Regularizer

From dadi, moments, and momi2

Principle

The site frequency spectrum (SFS) is a deterministic function of the ARG: \(\text{SFS}[k] = \sum_e b(e) \cdot \mathbf{1}[\text{descendants}(e) = k]\), where \(b(e)\) is the branch length of edge \(e\) and \(k\) is the number of descendant leaves. This relationship is differentiable and can be used as a physics-informed loss – a hard constraint from population genetics theory that the network must satisfy.

What dadi/moments/momi2 do. These methods compute the expected SFS under a demographic model using diffusion equations (dadi), moment equations (moments), or Moran-model tensor algebra (momi2). They then fit the model to the observed SFS.

Realization in Mainspring. The predicted ARG (topology + node times) implies a predicted SFS. We compute it differentiably:

def predicted_sfs(parent_array, node_times, n_leaves):
    """Compute the branch-length SFS from a predicted tree (differentiable in node_times).

    Assumes count_descendants(child, parent_array, n_leaves) (defined elsewhere) returns
    the number of sample leaves below a node, and that a root has parent_array[child] == -1.
    """
    sfs = torch.zeros(n_leaves + 1, dtype=node_times.dtype, device=node_times.device)
    for child in range(len(parent_array)):
        parent = parent_array[child]
        if parent < 0:  # the root has no parent edge
            continue
        branch_length = node_times[parent] - node_times[child]
        n_descendants = count_descendants(child, parent_array, n_leaves)
        sfs[n_descendants] = sfs[n_descendants] + branch_length
    return sfs[1:-1]  # polymorphic classes only (exclude 0 and n)

The SFS loss is then the \(\chi^2\) distance (or Poisson log-likelihood) between the predicted SFS and the true SFS computed from the genotype matrix. This loss does not require knowing the true ARG – only the observed allele frequencies – and acts as a physics-informed regularizer that guides the network toward ARGs consistent with the observed diversity pattern.
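
A minimal sketch of the Poisson form of this loss; `theta`, a mutation-rate scaling that converts branch lengths into expected mutation counts, is an assumed input:

def sfs_loss(pred_branch_sfs, obs_sfs, theta):
    """Poisson negative log-likelihood of the observed SFS, up to an additive constant.

    pred_branch_sfs: branch length per frequency class, from predicted_sfs above;
    obs_sfs: observed site counts per frequency class; theta: mutation-rate scaling.
    """
    rate = theta * pred_branch_sfs + 1e-8   # expected mutations per frequency class
    return torch.sum(rate - obs_sfs * torch.log(rate))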

Principle 7: Gamma-Distributed Branch-Length Posteriors

From Gamma-SMC

Principle

Coalescence times are positive, right-skewed random variables. The gamma distribution is a natural parametric family for their posteriors: it is supported on \((0, \infty)\), can represent both peaked and diffuse uncertainty, and is closed under the multiplicative updates that arise in sequential Bayesian inference.

What Gamma-SMC does. Gamma-SMC maintains a gamma-distributed belief over the coalescence time at each position, updating it as new sites are processed (see The Gamma Approximation). The shape parameter \(\alpha\) tracks the “evidence count” and the rate parameter \(\beta\) tracks the “evidence mass.”

Realization in Mainspring. The dating GNN’s output heads predict gamma parameters \((\alpha_v, \beta_v)\) for each node \(v\):

\[\alpha_v = \text{softplus}(\mathbf{w}_\alpha^\top \mathbf{h}_v + b_\alpha) + 1\]
\[\beta_v = \text{softplus}(\mathbf{w}_\beta^\top \mathbf{h}_v + b_\beta)\]

The \(+1\) in the shape parameter ensures \(\alpha_v > 1\), so the gamma density has a mode at \((\alpha_v - 1)/\beta_v > 0\), preventing degenerate distributions concentrated at zero. The loss on node times is the negative log-likelihood under the predicted gamma:

\[\mathcal{L}_{\text{time}} = -\sum_v \log \text{Gamma}(t_v^* \mid \alpha_v, \beta_v)\]

where \(t_v^*\) is the true node time from the simulation.
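
A sketch of the output heads and the time loss, using torch.distributions.Gamma; the layer shapes are illustrative and follow the softplus parameterization above:

class GammaTimeHead(nn.Module):
    """Map node features to gamma posteriors over node times (illustrative sketch)."""

    def __init__(self, d_node):
        super().__init__()
        self.alpha_head = nn.Linear(d_node, 1)
        self.beta_head = nn.Linear(d_node, 1)

    def forward(self, h):
        # h: (n_nodes, d_node) final node features from the dating GNN.
        alpha = nn.functional.softplus(self.alpha_head(h)).squeeze(-1) + 1.0  # shape > 1
        beta = nn.functional.softplus(self.beta_head(h)).squeeze(-1) + 1e-6   # rate > 0
        return alpha, beta

def time_loss(alpha, beta, true_times):
    """Negative log-likelihood of simulated node times under the predicted gammas.

    true_times: times of internal nodes only (leaves sit at time zero and are excluded).
    """
    dist = torch.distributions.Gamma(concentration=alpha, rate=beta)
    return -dist.log_prob(true_times).sum()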

Principle 8: Randomized Discretization

From phlash

Principle

Fixed discretization of time creates aliasing artifacts: the inferred \(N_e(t)\) is biased toward the grid. Randomizing the discretization during training – using different time grids for different training examples – forces the network to learn a representation that is robust to the choice of grid.

What phlash does. phlash uses random discretization of the coalescent time axis to avoid the aliasing artifacts that plague PSMC and SMC++ (see Random Time Discretization). Each gradient step uses a different random grid, so the posterior estimate averages over grid artifacts.

Realization in Mainspring. We apply randomized discretization in two places:

  1. Random Fourier features for positional encoding. Instead of fixed sinusoidal positional embeddings, we use random Fourier features whose frequencies are drawn fresh for each training batch. This prevents the network from memorizing position-specific artifacts.

  2. Varied attention window sizes. During training, the sliding-window size \(w\) is sampled uniformly from \([w_{\min}, w_{\max}]\) for each batch. This forces the encoder to be robust to different effective correlation lengths, just as phlash's random discretization forces robustness to different time grids (a short sketch follows the positional-encoding code below).

class RandomFourierPositionalEncoding(nn.Module):
    """Positional encoding whose Fourier frequencies are resampled on every forward pass."""

    def __init__(self, d_model, sigma=10.0):
        super().__init__()
        self.d_model = d_model
        self.sigma = sigma

    def forward(self, positions):
        # positions: (..., L) float tensor of genomic coordinates.
        # Frequencies are drawn fresh each call, so no fixed grid can be memorized.
        B = torch.randn(self.d_model // 2, device=positions.device) * self.sigma
        proj = positions.unsqueeze(-1) * B.unsqueeze(0)
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
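
For the second mechanism, the window can be resampled per batch without touching the layer's weights, since SlidingWindowAttention only reads window_size at forward time; the bounds and the loader here are placeholders:

# Illustrative training-loop fragment: resample the attention window each batch.
w_min, w_max = 128, 1024
for batch in loader:   # `loader` is an assumed DataLoader yielding (batch, positions, d_model) tensors
    layer.window_size = int(torch.randint(w_min, w_max + 1, (1,)).item())
    out = layer(batch)
    # ... rest of the training step ...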

Principle 9: Per-Segment Sufficient Statistics as Features

From Threads (dating)

Principle

For a genomic segment of length \(\ell\) bp with \(m\) mutations, the natural estimator of the total branch length is \(m / (\mu \ell)\), and the estimator of the TMRCA for a pair is \((m + 1) / \bigl((\rho + \mu)\,\ell\bigr)\). These per-segment sufficient statistics capture the essential information in each genomic region and should be provided as input features to any dating algorithm.

What Threads does. Threads dates each edge in a tree sequence using the mutation count on the edge and the genomic span it covers (see Dating Path Segments). The estimator \(\hat{t} = (m + 1) / \bigl((\rho + \mu)\,\ell\bigr)\) is simple but effective, combining recombination and mutation evidence.

Realization in Mainspring. The dating GNN receives per-edge features that include:

| Feature | Description |
|---------|-------------|
| \(m_e\) | Mutation count on edge \(e\) |
| \(s_e\) | Genomic span of edge \(e\) (in base pairs) |
| \(\hat{t}_e = (m_e + 1) / (\hat{\rho} s_e + \hat{\mu} s_e)\) | Threads-style natural time estimator |
| \(\log m_e\), \(\log s_e\) | Log-transformed versions (for numerical stability) |
| \(n_e\) | Number of descendant leaves below edge \(e\) |

These features are concatenated and passed through a linear projection to produce edge embeddings. The natural estimator \(\hat{t}_e\) is not the final answer – it is the initial condition for the GNN’s message-passing updates, analogous to how Threads’ initial estimates are refined by iterative dating algorithms.
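
A hedged sketch of this feature construction and projection; the rate estimates rho_hat and mu_hat are assumed inputs, and the +1 inside the log is our guard against zero-mutation edges:

class EdgeFeatureEmbedding(nn.Module):
    """Project per-edge sufficient statistics into edge embeddings (illustrative sketch)."""

    def __init__(self, d_edge):
        super().__init__()
        self.proj = nn.Linear(6, d_edge)   # six features per edge, as in the table above

    def forward(self, m_e, s_e, n_e, rho_hat, mu_hat):
        # Threads-style natural estimator, used as the initial condition for message passing.
        t_hat = (m_e + 1.0) / (rho_hat * s_e + mu_hat * s_e)
        feats = torch.stack(
            [m_e, s_e, t_hat, torch.log(m_e + 1.0), torch.log(s_e), n_e], dim=-1)
        return self.proj(feats)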

Principle 10: Log-Space Computation and Softplus Positivity

From PSMC (decoding and scaling)

Principle

Population-genetic quantities span many orders of magnitude: coalescence times range from \(10^1\) to \(10^6\) generations, population sizes from \(10^2\) to \(10^7\). Working in log-space prevents numerical underflow/overflow and ensures that the network treats a factor-of-two change in \(N_e\) the same whether it occurs at \(N_e = 1{,}000\) or \(N_e = 100{,}000\).

What PSMC does. PSMC’s forward-backward algorithm operates in log-space to prevent underflow over long genomic sequences (see Decoding the Clock). The time discretization is log-spaced, giving finer resolution in the recent past where the data is most informative. The output \(\lambda(t)\) is plotted on a log-log scale.

Realization in Mainspring. We enforce log-scale / positive-quantity conventions throughout:

  • Node times are predicted in log-space: the GNN outputs an unconstrained value \(\log \hat{t}_v\), which we exponentiate to recover \(\hat{t}_v\). This ensures positivity and gives the network equal relative precision across all time scales.

  • Population sizes are parameterized through softplus: \(N_e(t) = \text{softplus}(z(t)) = \log(1 + e^{z(t)})\), where \(z(t)\) is the unconstrained output of the demographic decoder. Softplus is smoother than ReLU at zero and better behaved than exponentiation for large values.

  • Branch lengths in the gamma output heads use log-parameterized rate: \(\beta_v = \exp(\tilde{\beta}_v)\), ensuring \(\beta_v > 0\) without clamping.

import torch.nn.functional as F

def decode_times(log_times_raw):
    """Convert raw network output to positive times."""
    return torch.exp(log_times_raw)

def decode_population_sizes(z):
    """Convert unconstrained output to positive N_e."""
    return F.softplus(z)

def decode_gamma_params(alpha_raw, beta_raw):
    """Convert raw outputs to valid gamma parameters."""
    alpha = F.softplus(alpha_raw) + 1.0  # ensure alpha > 1
    beta = torch.exp(beta_raw)           # ensure beta > 0
    return alpha, beta

Summary

Ten design principles and their origins

| # | Source Timepiece | Principle | Neural realization |
|---|------------------|-----------|--------------------|
| 1 | PSMC | Sequential Markov structure | Sliding-window local attention |
| 2 | SMC++ | Permutation invariance over samples | Set Transformer (ISAB) |
| 3 | ARGweaver / SINGER | Full ARG is sufficient statistic | Output full ARG in tskit format |
| 4 | tsdate | Message passing for dating | GNN on local trees |
| 5 | tsinfer / Threads | Li & Stephens = attention | Cross-attention + Gumbel-softmax |
| 6 | dadi / moments / momi2 | SFS as auxiliary loss | Differentiable SFS computation |
| 7 | Gamma-SMC | Gamma posteriors for branch lengths | Gamma output heads on GNN |
| 8 | phlash | Randomized discretization | Random Fourier features, varied window sizes |
| 9 | Threads (dating) | Per-segment sufficient statistics | Edge features: \(m_e, s_e, \hat{t}_e\) |
| 10 | PSMC (decoding) | Log-space, softplus positivity | Log-time prediction, softplus \(N_e\) |

These principles are not post-hoc rationalizations. They were the starting point of the design process: we asked, for each Timepiece, “what is the one structural insight that makes this method work?” and then found the neural analogue. The next chapter assembles these principles into a complete architecture.