Complication I: Mainspring

Amortized ARG Inference via Structured Neural Posterior Estimation

The Mechanism at a Glance

Every method in this book makes the same fundamental trade: mathematical tractability vs. biological realism. PSMC discretizes time and assumes piecewise-constant demography. dadi collapses the genome to a frequency histogram. ARGweaver is exact under the DSMC model but costs \(O(S^2)\) per site. tsinfer scales to millions of samples but surrenders posterior inference entirely. Each Timepiece chooses a different point on the Pareto frontier between accuracy and compute.

Deep learning can break this frontier. Not by replacing the gear train – the likelihood machinery the book spends hundreds of pages building – but by learning to shortcut it. Simulate millions of ARGs from the generative model (msprime). Train a neural network to invert the simulation: map observed sequences back to the ARG and the demography that produced them. At inference time, a single forward pass replaces hours of MCMC or EM.
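To fix ideas, here is a minimal sketch of the simulation side, using msprime's sim_ancestry and sim_mutations together with tskit's genotype_matrix. The constant-size, log-uniform prior on \(N_e\) and the rates below are illustrative placeholders, not Mainspring's actual training distribution.

    import msprime
    import numpy as np

    rng = np.random.default_rng(1)

    def simulate_one(n_haplotypes=50, seq_len=1_000_000):
        """Draw one (genotype matrix, ARG, demography) training triple.
        The constant-size, log-uniform prior on N_e is a placeholder."""
        ne = 10 ** rng.uniform(3, 5)
        ts = msprime.sim_ancestry(
            samples=n_haplotypes, ploidy=1, population_size=ne,
            sequence_length=seq_len, recombination_rate=1e-8,
            random_seed=int(rng.integers(1, 2**31)),
        )
        ts = msprime.sim_mutations(
            ts, rate=1e-8, model=msprime.BinaryMutationModel(),
            random_seed=int(rng.integers(1, 2**31)),
        )
        D = ts.genotype_matrix().T   # (haplotypes, variant sites), entries in {0, 1}
        return D, ts, ne

    # a training set is millions of such draws; the network learns the inverse map
    D, true_arg, true_ne = simulate_one()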

The question is: what architecture respects the mathematical structure of the problem well enough to learn efficiently? This Complication answers that question by distilling one design principle from each Timepiece.

The name

The mainspring is the power source of a mechanical watch – a coiled spring that stores energy and releases it through the gear train to drive the hands. In Mainspring, simulated ARGs are the stored energy, and the neural network is the gear train that converts them into inference power. As with a physical mainspring, the winding happens in advance (during training) and the energy is released in precisely metered pulses (at inference time).

The four stages of Mainspring, each illustrated by a short code sketch below:

  1. The Genomic Encoder (the escapement) – A Transformer that processes the genotype matrix with sliding-window attention over positions (from PSMC’s sequential Markov property) and Set Transformer attention over samples (from SMC++’s permutation invariance). Output: per-sample, per-position latent vectors.

  2. The Topology Decoder (the gear train) – A learned Li & Stephens model. Cross-attention between haplotypes identifies who is copying whom at each position (from tsinfer/lshmm). Hard attention via Gumbel-softmax yields discrete parent assignments. A breakpoint detector identifies where trees change. Output: a tree sequence topology.

  3. The Dating GNN (the mainspring) – A Graph Neural Network that runs learned message-passing on each local tree (from tsdate’s inside-outside algorithm). Edge features include mutation count and genomic span (from Threads’ sufficient statistics). Output: gamma-distributed node times (from Gamma-SMC).

  4. The Demographic Decoder (the case and dial) – A conditional normalizing flow that maps the inferred coalescence-time distribution to a posterior over continuous \(N_e(t)\) functions (from phlash). An SFS auxiliary loss provides physics-informed regularization (from dadi/moments/momi2).
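To make stage 1 concrete, here is a minimal sketch of a single encoder block in plain PyTorch. Two simplifications relative to the description above: the sliding window is approximated by non-overlapping chunks of positions, and the Set Transformer's inducing-point attention is replaced by ordinary multi-head attention across the sample axis (still permutation-equivariant, since no positional encoding is applied over samples). Names and dimensions are illustrative.

    import torch
    import torch.nn as nn

    class EncoderBlock(nn.Module):
        """Axial attention: over positions within a local window, then over samples."""
        def __init__(self, d_model=64, n_heads=4, window=128):
            super().__init__()
            self.window = window
            self.pos_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.sample_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)

        def forward(self, h):                       # h: (samples n, positions L, d_model)
            n, L, d = h.shape
            # 1) windowed attention over positions (padding is left unmasked in this sketch)
            pad = (-L) % self.window
            hp = nn.functional.pad(h, (0, 0, 0, pad))
            w = hp.shape[1] // self.window
            hp = hp.reshape(n * w, self.window, d)
            attn, _ = self.pos_attn(hp, hp, hp)
            hp = self.norm1(hp + attn).reshape(n, w * self.window, d)[:, :L]
            # 2) attention across samples at each position; no positional encoding over
            #    samples, so the block is permutation-equivariant in the sample axis
            hs = hp.transpose(0, 1)                 # (L, n, d)
            attn, _ = self.sample_attn(hs, hs, hs)
            return self.norm2(hs + attn).transpose(0, 1)   # back to (n, L, d)

    # toy shapes: 16 haplotypes, 1000 variant sites, 64-dim embeddings
    z = EncoderBlock()(torch.randn(16, 1000, 64))   # -> (16, 1000, 64)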

Genotype matrix D ∈ {0,1}^{n × L}
                   |
                   v
         +--------------------------+
         |  GENOMIC ENCODER         |
         |  Sliding-window attn     |
         |  (PSMC: sequential SMC)  |
         |  Set Transformer         |
         |  (SMC++: permutation eq) |
         +--------------------------+
                   |
                   v
         +--------------------------+
         |  TOPOLOGY DECODER        |
         |  Cross-attention         |
         |  (tsinfer: copying model)|
         |  Gumbel-softmax → edges  |
         |  Breakpoint detection    |
         +--------------------------+
                   |
                   v
         +--------------------------+
         |  DATING GNN              |
         |  Message passing on trees|
         |  (tsdate: inside-outside)|
         |  Gamma(α, β) output      |
         |  (Gamma-SMC: posteriors) |
         +--------------------------+
                   |
                   v
         +--------------------------+
         |  DEMOGRAPHIC DECODER     |
         |  Normalizing flow → N_e  |
         |  (phlash: continuous)    |
         |  SFS auxiliary loss      |
         |  (dadi/moments: physics) |
         +--------------------------+
                   |
                   v
         Full dated ARG + N_e(t)
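Stage 2's central trick is making the discrete choice of who copies from whom differentiable. A minimal sketch in plain PyTorch: cross-attention scores between the focal haplotype and its candidate parents act as copying logits, and torch.nn.functional.gumbel_softmax with hard=True returns a one-hot parent per position while passing a soft gradient backward. The function below illustrates the mechanism; it is not Mainspring's exact decoder.

    import torch
    import torch.nn.functional as F

    def assign_parents(h_focal, h_candidates, tau=0.5):
        """Choose, at every position, which candidate haplotype the focal one copies from.

        h_focal:      (L, d)    latents for the haplotype being threaded
        h_candidates: (k, L, d) latents for the k candidate parents
        returns one-hot assignments (L, k) and the soft copying probabilities
        """
        # cross-attention scores: similarity of the focal latent to each candidate, per position
        scores = torch.einsum("ld,kld->lk", h_focal, h_candidates) / h_focal.shape[-1] ** 0.5
        probs = scores.softmax(dim=-1)
        # hard one-hot sample in the forward pass, softmax gradient in the backward pass
        one_hot = F.gumbel_softmax(scores, tau=tau, hard=True, dim=-1)
        return one_hot, probs

    # toy example: 3 candidate parents, 200 positions, 32-dim latents
    one_hot, probs = assign_parents(torch.randn(200, 32), torch.randn(3, 200, 32))
    # positions where the hard assignment switches are candidate recombination breakpoints
    switches = (one_hot.argmax(-1)[1:] != one_hot.argmax(-1)[:-1]).nonzero().squeeze(-1) + 1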
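Stage 3 pairs message passing over the parent-child edges of each local tree with a head that emits a gamma distribution per node, trained by negative log-likelihood against the simulated true node times. A minimal sketch in plain PyTorch (no graph library), with edge features restricted to the mutation count and genomic span named above; everything here is illustrative.

    import torch
    import torch.nn as nn

    class GammaDatingLayer(nn.Module):
        """One round of child-to-parent message passing on a local tree, plus a gamma head."""
        def __init__(self, d=32):
            super().__init__()
            self.edge_mlp = nn.Sequential(nn.Linear(2 * d + 2, d), nn.ReLU(), nn.Linear(d, d))
            self.node_mlp = nn.Sequential(nn.Linear(2 * d, d), nn.ReLU(), nn.Linear(d, d))
            self.gamma_head = nn.Linear(d, 2)             # -> (log alpha, log beta) per node

        def forward(self, h, edges, edge_feats):
            # h: (num_nodes, d); edges: (num_edges, 2) child->parent; edge_feats: (num_edges, 2)
            child, parent = edges[:, 0], edges[:, 1]
            msg = self.edge_mlp(torch.cat([h[child], h[parent], edge_feats], dim=-1))
            agg = torch.zeros_like(h).index_add_(0, parent, msg)   # sum child messages at parents
            h = h + self.node_mlp(torch.cat([h, agg], dim=-1))
            log_alpha, log_beta = self.gamma_head(h).unbind(-1)
            return h, log_alpha.exp(), log_beta.exp()

    def gamma_nll(alpha, beta, true_times, internal):
        """Training loss: NLL of simulated node times under Gamma(alpha, beta), internal nodes only."""
        dist = torch.distributions.Gamma(alpha[internal], beta[internal])
        return -dist.log_prob(true_times[internal].clamp_min(1e-6)).mean()

Summing child messages into parents echoes tsdate's upward (inside) pass; a fuller sketch would alternate upward and downward rounds to mirror the whole inside-outside sweep.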
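Stage 4 discretizes \(N_e(t)\) on a fixed grid of time points and learns a conditional density over log-\(N_e\) trajectories given a summary of the inferred coalescence times. A minimal sketch of a conditional affine-coupling flow and its training objective, in plain PyTorch; the coupling architecture, grid size, and summary dimension are illustrative choices, not phlash's or Mainspring's actual flow.

    import torch
    import torch.nn as nn

    class ConditionalCoupling(nn.Module):
        """Affine coupling: scale and shift half the dimensions using the other half plus context."""
        def __init__(self, dim, context_dim, hidden=64):
            super().__init__()
            self.d = dim // 2
            self.net = nn.Sequential(
                nn.Linear(self.d + context_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, 2 * (dim - self.d)),
            )

        def forward(self, x, ctx):
            x1, x2 = x[:, :self.d], x[:, self.d:]
            s, t = self.net(torch.cat([x1, ctx], dim=-1)).chunk(2, dim=-1)
            s = torch.tanh(s)                                 # keep the scale bounded
            return torch.cat([x1, x2 * s.exp() + t], dim=-1), s.sum(-1)

    class DemographicFlow(nn.Module):
        """Stack of couplings; log-density of a log-N_e trajectory given a coalescence summary."""
        def __init__(self, dim=16, context_dim=32, n_layers=4):
            super().__init__()
            self.layers = nn.ModuleList(
                [ConditionalCoupling(dim, context_dim) for _ in range(n_layers)]
            )

        def log_prob(self, log_ne, ctx):                      # log_ne: (batch, K time-grid points)
            z, log_det = log_ne, 0.0
            for layer in self.layers:
                z = z.flip(-1)                                # alternate which half is conditioned on
                z, ld = layer(z, ctx)
                log_det = log_det + ld
            base = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(-1)
            return base + log_det

    # training objective: maximize log q(log N_e trajectory | coalescence summary) on simulated pairs
    flow = DemographicFlow()
    loss = -flow.log_prob(torch.randn(8, 16), torch.randn(8, 32)).mean()

Maximizing this conditional log-density over simulated (trajectory, summary) pairs is the standard neural posterior estimation objective named at the top of this Complication.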

Prerequisites for this Complication

Before starting Mainspring, you should have worked through:

  • PSMC – the sequential Markov property and HMM inference

  • tsinfer – the Li & Stephens copying model and tree sequence representation

  • tsdate – the inside-outside algorithm on trees and variational gamma posteriors

  • The SMC – the sequential Markov coalescent approximation

  • msprime – the coalescent simulator (for generating training data)

Familiarity with Transformer architectures and PyTorch is assumed.

Chapters