Training
A mainspring must be wound carefully. Too little tension and the watch stops. Too much and the spring breaks. The art is in the winding – gradually, evenly, with increasing force.
Training Mainspring requires three ingredients: a simulation engine that generates (genotype matrix, true ARG, true demography) triples, a composite loss function that scores each component of the prediction, and a curriculum that gradually increases the complexity of the training distribution. This chapter builds all three.
The Simulation Engine
Mainspring learns by inverting simulations. The simulator is the generative model; the network learns to be its approximate inverse. The quality of the training data determines the ceiling of inference quality – no network can learn to infer features absent from the training simulations.
We use msprime as the simulation engine. Each training example is generated by:
1. Sampling a demographic model from a prior distribution.
2. Simulating a tree sequence under that demographic model.
3. Extracting the genotype matrix, the true ARG, and the true demography.
import msprime
import numpy as np

def sample_demography(rng):
    """Sample a random demographic model from the training prior.

    Returns an msprime.Demography object, a callable N_e(t), and the raw
    change times and sizes used to build it.
    """
    n_epochs = rng.integers(1, 8)                        # 1-7 piecewise-constant size changes
    times = np.sort(rng.exponential(scale=5000, size=n_epochs))
    times = np.concatenate([[0], times])                 # epoch boundaries, starting at time 0
    sizes = 10 ** rng.uniform(2, 5, size=n_epochs + 1)   # log-uniform N_e in [1e2, 1e5]

    demography = msprime.Demography()
    demography.add_population(name="pop", initial_size=sizes[0])
    for i in range(1, len(times)):
        demography.add_population_parameters_change(
            time=times[i], initial_size=sizes[i], population="pop"
        )

    def ne_func(t):
        # Piecewise-constant N_e(t): find the epoch containing time t.
        idx = np.searchsorted(times, t, side='right') - 1
        return sizes[idx]

    return demography, ne_func, times, sizes

def simulate_training_example(n_samples, seq_length, mu, rho, rng):
    """Generate one (genotype_matrix, true_ARG, true_demography) triple."""
    demography, ne_func, times, sizes = sample_demography(rng)
    ts = msprime.sim_ancestry(
        samples=n_samples,
        sequence_length=seq_length,
        recombination_rate=rho,
        demography=demography,
        random_seed=rng.integers(1, 2**31),
    )
    ts = msprime.sim_mutations(ts, rate=mu, random_seed=rng.integers(1, 2**31))
    genotype_matrix = ts.genotype_matrix().T  # (n_haplotypes, n_sites)
    return {
        'genotypes': genotype_matrix,
        'tree_sequence': ts,
        'ne_func': ne_func,
        'ne_times': times,
        'ne_sizes': sizes,
    }
Extracting training targets from the tree sequence
The tree sequence ts returned by msprime contains everything we need:
- Topology targets: For each local tree, the parent array `tree.parent_array` gives the true topology.
- Breakpoint targets: `ts.breakpoints()` gives the true positions where trees change.
- Node time targets: `ts.tables.nodes.time` gives the true time of every node.
- Demography targets: The `ne_func` callable gives the true \(N_e(t)\) at any time.
The genotype matrix is the input; everything else is a training target.
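A short sketch of how these targets can be collected from a tree sequence; the function name and return layout are illustrative, not part of Mainspring's API:

import numpy as np

def extract_targets(ts):
    """Collect per-tree topology targets plus global breakpoint and time targets."""
    parent_arrays = []   # one true parent array per local tree
    spans = []           # genomic span of each local tree
    for tree in ts.trees():
        # Copy: the underlying array may be reused as the tree iterator advances.
        parent_arrays.append(np.array(tree.parent_array))
        spans.append(tree.interval.right - tree.interval.left)
    return {
        'parent_arrays': parent_arrays,
        'spans': np.array(spans),
        'breakpoints': np.fromiter(ts.breakpoints(), dtype=float),
        'node_times': ts.tables.nodes.time,
    }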
Scaling the Simulation Pipeline
Training requires millions of simulated datasets. Generating them on the fly – a fresh batch for every gradient step – is essential to avoid overfitting to a finite training set. A typical training configuration:
| Parameter | Value | Rationale |
|---|---|---|
| \(n\) (samples) | 20–100 | Covers typical sample sizes for ARG inference |
| \(L\) (sequence length) | 50 kb – 1 Mb | Covers gene-scale to chromosome-arm-scale |
| \(\mu\) (mutation rate) | \(1.25 \times 10^{-8}\) / bp / gen | Human mutation rate |
| \(\rho\) (recombination rate) | \(1.0 \times 10^{-8}\) / bp / gen | Human recombination rate |
| \(N_e\) range | \(10^2\) – \(10^5\) | Covers bottlenecks through large populations |
| Number of epochs | 1–7 | Covers constant through complex demography |
| Simulations per GPU-hour | ~10,000 (100 kb, 50 samples) | msprime is fast; I/O is the bottleneck |
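The \(n\) and \(L\) ranges in the table can also be sampled per example rather than fixed; a minimal sketch of such a sampler (illustrative only – the dataset class below fixes one configuration for simplicity):

import numpy as np

def sample_simulation_config(rng):
    # Draw sample size and sequence length from the ranges in the table above.
    n_samples = int(rng.integers(20, 101))                              # 20-100 samples
    seq_length = int(10 ** rng.uniform(np.log10(5e4), np.log10(1e6)))   # 50 kb - 1 Mb
    return n_samples, seq_length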
import numpy as np
import torch
from torch.utils.data import IterableDataset

class MsprimeDataset(IterableDataset):
    """Infinite stream of freshly simulated training examples."""

    def __init__(self, n_samples, seq_length, mu, rho):
        self.n_samples = n_samples
        self.seq_length = seq_length
        self.mu = mu
        self.rho = rho

    def __iter__(self):
        # Fresh RNG per iterator so parallel DataLoader workers don't duplicate data.
        rng = np.random.default_rng()
        while True:
            example = simulate_training_example(
                self.n_samples, self.seq_length, self.mu, self.rho, rng
            )
            yield self.tensorize(example)

    def tensorize(self, example):
        ts = example['tree_sequence']
        return {
            'genotypes': torch.tensor(example['genotypes'], dtype=torch.float32),
            'node_times': torch.tensor(ts.tables.nodes.time, dtype=torch.float32),
            'ne_sizes': torch.tensor(example['ne_sizes'], dtype=torch.float32),
            'ne_times': torch.tensor(example['ne_times'], dtype=torch.float32),
        }
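Because the number of nodes, sites, and epochs varies between simulated ARGs, the default DataLoader collation cannot stack these dictionaries directly. One workaround is a padding collate function, sketched here under the assumption that downstream losses mask out the padded entries (the padding value of -1 is an arbitrary choice, not part of Mainspring):

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def pad_collate(batch):
    """Stack a list of example dicts, right-padding ragged dimensions with -1."""
    out = {}
    for key in batch[0]:
        tensors = [ex[key] for ex in batch]
        max_shape = [max(t.shape[d] for t in tensors) for d in range(tensors[0].dim())]
        padded = []
        for t in tensors:
            pad = []
            for d in reversed(range(t.dim())):       # F.pad expects last-dim-first pairs
                pad.extend([0, max_shape[d] - t.shape[d]])
            padded.append(F.pad(t, pad, value=-1.0))
        out[key] = torch.stack(padded)
    return out

loader = DataLoader(MsprimeDataset(50, 100_000, 1.25e-8, 1.0e-8),
                    batch_size=16, collate_fn=pad_collate, num_workers=4)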
The Loss Function
Mainspring’s loss is a weighted sum of four components, each corresponding to a stage of the architecture:
| Component | Symbol | What it measures | Stage |
|---|---|---|---|
| Topology loss | \(\mathcal{L}_{\text{topology}}\) | Cross-entropy between predicted and true parent assignments + binary cross-entropy for breakpoints | Stage 2 (Topology Decoder) |
| Time loss | \(\mathcal{L}_{\text{time}}\) | Negative log-likelihood of true node times under predicted gamma distributions | Stage 3 (Dating GNN) |
| SFS loss | \(\mathcal{L}_{\text{SFS}}\) | \(\chi^2\) distance between predicted and observed SFS | Stages 2+3 (physics regularizer) |
| Demographic loss | \(\mathcal{L}_{\text{demo}}\) | Negative log-likelihood of true \(N_e(t)\) under the normalizing flow posterior + KL penalty | Stage 4 (Demographic Decoder) |
The weights \(\lambda_{\text{time}}\), \(\lambda_{\text{SFS}}\), and \(\lambda_{\text{demo}}\) balance the loss components. They are adjusted during curriculum training (see below).
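Assembled exactly as in the training loop at the end of this chapter, the total objective is

\[
\mathcal{L} \;=\; \mathcal{L}_{\text{topology}}
\;+\; \lambda_{\text{time}}\,\mathcal{L}_{\text{time}}
\;+\; \lambda_{\text{SFS}}\,\mathcal{L}_{\text{SFS}}
\;+\; \lambda_{\text{demo}}\,\mathcal{L}_{\text{demo}}.
\]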
Topology Loss
The topology loss has two terms – a cross-entropy over parent assignments and a binary cross-entropy over breakpoints:

\[
\mathcal{L}_{\text{topology}}
= -\sum_{\ell}\sum_{i} \log \alpha^{\ell}_{i,\pi_i^*(\ell)}
\;+\; \mathrm{BCE}\!\left(\hat{b}, b^*\right),
\]

where \(\pi_i^*(\ell)\) is the true parent of sample \(i\) in the local tree at position \(\ell\), \(\alpha_{ij}^\ell\) is the predicted attention weight (copying probability), \(\hat{b}\) is the predicted breakpoint vector, and \(b^*\) is the true breakpoint indicator.
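A minimal PyTorch sketch of this loss, with a simplified calling convention (the training loop at the end of the chapter passes the raw batch rather than these unpacked tensors):

import torch
import torch.nn.functional as F

def topology_loss(alpha, true_parents, b_hat, b_true):
    """Cross-entropy over parent assignments plus BCE over breakpoints.

    alpha:        (n_trees, n_samples, n_nodes) copying probabilities
    true_parents: (n_trees, n_samples) index of each sample's true parent
    b_hat:        (n_sites,) predicted breakpoint probabilities
    b_true:       (n_sites,) 0/1 true breakpoint indicators
    """
    picked = alpha.gather(-1, true_parents.unsqueeze(-1)).squeeze(-1)
    parent_ce = -torch.log(picked + 1e-12).mean()
    breakpoint_bce = F.binary_cross_entropy(b_hat, b_true.float())
    return parent_ce + breakpoint_bce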
Time Loss
The time loss is the negative log-likelihood of the true node times under the predicted gamma distributions:

\[
\mathcal{L}_{\text{time}}
= -\sum_{v \in \mathcal{V}_{\text{int}}}
\log \mathrm{Gamma}\!\left(t_v^* \,\middle|\, \hat{\alpha}_v, \hat{\beta}_v\right),
\]

where \(\mathcal{V}_{\text{int}}\) is the set of internal nodes, \(t_v^*\) is the true time of node \(v\), and \(\hat{\alpha}_v, \hat{\beta}_v\) are the gamma parameters predicted by the dating GNN.
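A sketch matching the time_loss call in the training loop, assuming a shape/rate parameterization and that sample (leaf) nodes sit at time zero and are excluded:

import torch

def time_loss(alpha, beta, true_times):
    """Gamma negative log-likelihood of true node times (internal nodes only)."""
    mask = true_times > 0                     # leaves have time 0; keep internal nodes
    dist = torch.distributions.Gamma(alpha[mask], beta[mask])
    return -dist.log_prob(true_times[mask]).mean()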
The SFS Loss as a Differentiable Physics Regularizer
The SFS loss deserves special attention because it is the key connection between the Timepieces that operate on summary statistics (dadi, moments, momi2) and the full ARG-based inference of Mainspring.
The SFS is a deterministic function of the ARG. For a sample of \(n\) haplotypes, the expected count of sites in frequency class \(k\) is

\[
\mathbb{E}\!\left[\mathrm{SFS}_k\right]
\;=\; \mu \sum_{e\,:\,\mathrm{desc}(e) = k} b(e),
\qquad k = 1, \dots, n-1,
\]

where \(b(e)\) is the branch length (in generations) of edge \(e\), \(\mathrm{desc}(e)\) is the number of descendant leaves below edge \(e\), and \(\mu\) is the per-generation mutation rate. This formula says: the expected number of sites with derived allele frequency \(k/n\) equals the total branch length subtending exactly \(k\) leaves, times the mutation rate.
This relationship is differentiable with respect to the predicted branch lengths. If the network predicts node times \(\hat{t}_v\) (from the dating GNN), the branch length of edge \((u, v)\) is \(\hat{b}_{uv} = \hat{t}_u - \hat{t}_v\), and we can compute:
import torch

def differentiable_sfs(node_times, parent_array, n_leaves, mu, span):
    """Compute the expected SFS from predicted node times.

    This is differentiable w.r.t. node_times, allowing gradient flow
    from the SFS loss back through the dating GNN. `count_descendants_below`
    is a helper that counts the leaves subtended by a node (see below).
    """
    n = n_leaves
    sfs = torch.zeros(n + 1, dtype=node_times.dtype, device=node_times.device)
    for child in range(len(parent_array)):
        parent = parent_array[child]
        if parent < 0:
            continue                          # root: no branch above it
        branch_length = node_times[parent] - node_times[child]
        n_desc = count_descendants_below(child, parent_array, n_leaves)
        sfs[n_desc] = sfs[n_desc] + branch_length * mu * span
    return sfs[1:n]                           # SFS[1] through SFS[n-1]
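count_descendants_below is not a library function; one simple version (it only touches integer topology, so it needs no gradients) counts which leaves reach the node by walking each leaf up the parent array:

def count_descendants_below(node, parent_array, n_leaves):
    """Number of sample leaves (IDs 0..n_leaves-1) whose ancestry passes through `node`."""
    count = 0
    for leaf in range(n_leaves):
        u = leaf
        while u >= 0:                 # walk towards the root; -1 marks the root's parent
            if u == node:
                count += 1
                break
            u = parent_array[u]
    return count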
Given the predicted SFS, the SFS loss is

\[
\mathcal{L}_{\text{SFS}}
= \sum_{k=1}^{n-1}
\frac{\left(\widehat{\mathrm{SFS}}_k - \mathrm{SFS}^{\mathrm{obs}}_k\right)^2}
     {\mathrm{SFS}^{\mathrm{obs}}_k + \epsilon}.
\]

This is a \(\chi^2\)-type loss: the squared error in each frequency class is normalized by the observed count in that class, and the \(\epsilon\) floor prevents division by zero in rare frequency classes where the observed SFS may be zero or very small.
Why this works as a regularizer
The SFS loss does not require knowing the true ARG. It compares the SFS implied by the predicted ARG to the SFS computed directly from the observed genotype matrix (which is always available). This provides a supervision signal that is independent of the topology and time losses – it catches global errors that per-node losses miss. For example, if the network systematically under-estimates deep coalescence times, the predicted SFS will have too few singletons (because deep branches subtend many descendants, shifting weight from low to high frequency classes). The SFS loss detects and corrects this.
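A per-example sketch of the corresponding sfs_loss used in the training loop, assuming genotypes is a 0/1 matrix of shape (n_haplotypes, n_sites) with 1 marking the derived allele (batched inputs would be mapped over the batch dimension):

import torch

def observed_sfs(genotypes):
    """Observed site frequency spectrum from a 0/1 haplotype matrix."""
    n = genotypes.shape[0]                               # number of haplotypes
    counts = genotypes.sum(dim=0).long()                 # derived-allele count at each site
    sfs = torch.bincount(counts, minlength=n + 1).float()
    return sfs[1:n]                                      # polymorphic classes 1..n-1

def sfs_loss(predicted_sfs, genotypes, eps=1.0):
    """Chi-squared distance between predicted and observed SFS."""
    obs = observed_sfs(genotypes).to(predicted_sfs.device)
    return ((predicted_sfs - obs) ** 2 / (obs + eps)).sum()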
Demographic Loss
The demographic loss trains the normalizing flow to produce accurate posterior distributions over \(N_e(t)\). For each training example, we have the true \(N_e(t)\) trajectory sampled from our prior. The loss is the negative log-likelihood of the true trajectory under the flow:

\[
\mathcal{L}_{\text{demo}}
= -\log p_Z\!\left(\mathbf{z}^*\right)
\;-\; \log\left|\det J_{g_\phi^{-1}}\!\left(N_e^*\right)\right|,
\]

where \(\mathbf{z}^* = g_\phi^{-1}(N_e^*)\) is the true trajectory mapped back to the base distribution, and the second term is the log-determinant of the Jacobian of the inverse flow.
Curriculum Training
Training Mainspring on the full prior from the start is inefficient. Complex demographic models produce ARGs with deep coalescence events, many breakpoints, and wide variation in branch lengths – all of which are difficult for an untrained network to predict. Instead, we use curriculum training: a sequence of phases that gradually increase the complexity of the training distribution.
| Phase | Demography | Focus | Loss weights | Duration |
|---|---|---|---|---|
| 1 | Constant \(N_e\) | Topology and basic dating | \(\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=0.1, \lambda_{\text{demo}}=0\) | 100k steps |
| 2 | 1–2 size changes | Dating under variable \(N_e\) | \(\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=0.5, \lambda_{\text{demo}}=0.1\) | 200k steps |
| 3 | Full prior (1–7 epochs) | Demographic inference | \(\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=1, \lambda_{\text{demo}}=1\) | 500k steps |
| 4 | Complex + selection (SLiM) | Robustness to model misspecification | \(\lambda_{\text{time}}=1, \lambda_{\text{SFS}}=1, \lambda_{\text{demo}}=1\) | 200k steps |
Phase 1: Constant demography. All simulations use a single constant population size \(N_e \sim 10^{\mathcal{U}(2,5)}\). The Gumbel-softmax temperature starts high (\(\tau = 5\)) and is annealed to \(\tau = 1\) by the end of Phase 1. The network learns basic topology reconstruction and time estimation without needing to handle demographic variation.
Phase 2: Simple demographic changes. Simulations include one or two step changes in population size. The demographic decoder is activated (\(\lambda_{\text{demo}} > 0\)), and the network begins learning to map coalescence-time distributions to \(N_e(t)\). The SFS loss weight increases to \(0.5\), strengthening the physics regularizer as the ARGs become more complex.
Phase 3: Full complexity. The training prior covers the full range of demographic models (1–7 epochs, arbitrary population sizes). All loss weights are set to 1. The Gumbel-softmax temperature continues annealing toward \(\tau = 0.1\). This is the longest phase and where most of the learning happens.
Phase 4: Robustness (optional). Training examples include simulations from SLiM (which can model natural selection, population structure, and other complexities not available in msprime). The network sees data generated under model misspecification – the true generative model is more complex than the coalescent assumed by the architecture. This teaches the network to degrade gracefully rather than produce confidently wrong answers.
Gumbel-Softmax Annealing
The temperature \(\tau\) of the Gumbel-softmax in the topology decoder follows an exponential annealing schedule:

\[
\tau(t) = \tau_{\max}\left(\frac{\tau_{\min}}{\tau_{\max}}\right)^{t/T},
\]

where \(t\) is the training step, \(T\) is the total number of steps, and typically \(\tau_{\max} = 5.0\), \(\tau_{\min} = 0.05\).
def anneal_temperature(step, total_steps, tau_max=5.0, tau_min=0.05):
    return tau_max * (tau_min / tau_max) ** (step / total_steps)
Training Pseudocode
The complete training loop:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

def train_mainspring(model, n_steps=1_000_000, batch_size=16,
                     lr=3e-4, device='cuda'):
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_steps)
    dataset = MsprimeDataset(n_samples=50, seq_length=100_000,
                             mu=1.25e-8, rho=1.0e-8)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    # Curriculum phases: cumulative step boundaries and per-phase loss weights.
    phase_boundaries = [100_000, 300_000, 800_000, 1_000_000]
    lambda_configs = [
        {'time': 1.0, 'sfs': 0.1, 'demo': 0.0},
        {'time': 1.0, 'sfs': 0.5, 'demo': 0.1},
        {'time': 1.0, 'sfs': 1.0, 'demo': 1.0},
        {'time': 1.0, 'sfs': 1.0, 'demo': 1.0},
    ]

    model.to(device)
    model.train()
    for step, batch in enumerate(loader):
        if step >= n_steps:
            break
        phase = sum(step >= b for b in phase_boundaries[:-1])
        lambdas = lambda_configs[phase]
        model.topology_decoder.tau = anneal_temperature(step, n_steps)

        genotypes = batch['genotypes'].to(device)
        true_times = batch['node_times'].to(device)
        true_ne = batch['ne_sizes'].to(device)

        outputs = model(genotypes)
        L_topo = topology_loss(outputs['topology'], batch, genotypes)
        L_time = time_loss(outputs['alpha'], outputs['beta'], true_times)
        L_sfs = sfs_loss(outputs['predicted_sfs'], genotypes)
        L_demo = demo_loss(outputs['ne_posterior'], outputs['flow_log_det'],
                           true_ne)
        loss = (L_topo
                + lambdas['time'] * L_time
                + lambdas['sfs'] * L_sfs
                + lambdas['demo'] * L_demo)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        if step % 1000 == 0:
            print(f"Step {step}: loss={loss.item():.4f} "
                  f"topo={L_topo.item():.4f} time={L_time.item():.4f} "
                  f"sfs={L_sfs.item():.4f} demo={L_demo.item():.4f} "
                  f"tau={model.topology_decoder.tau:.3f}")
Training Diagnostics
Monitoring training requires tracking each loss component independently, plus several diagnostic metrics:
| Metric | What it tells you |
|---|---|
| Topology accuracy (Robinson-Foulds distance) | Whether the predicted tree topologies match ground truth. Should decrease rapidly in Phase 1 and plateau in Phase 3. |
| Time calibration (predicted gamma coverage) | Whether the predicted gamma distributions are well-calibrated: the fraction of true node times falling within the predicted 90% credible interval should be close to 0.9. |
| SFS residuals | The per-frequency-class difference between predicted and observed SFS. Systematic biases (e.g., too few singletons) indicate structural errors in the predicted ARG. |
| \(N_e(t)\) RMSE (log scale) | Root mean squared error between predicted and true \(N_e(t)\) on a log scale. Should decrease steadily throughout Phases 2–3. |
| Gumbel-softmax entropy | The entropy of the attention weights in the topology decoder. Should decrease as \(\tau\) anneals, indicating that parent assignments are becoming more confident. |
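A sketch of the calibration and RMSE checks, assuming the dating GNN's gamma parameters are shape/rate and are passed as NumPy arrays (scipy is used because torch.distributions.Gamma has no inverse CDF):

import numpy as np
from scipy.stats import gamma as gamma_dist

def gamma_coverage(shape, rate, true_times, level=0.90):
    """Fraction of true node times inside the central `level` credible interval."""
    tail = (1.0 - level) / 2.0
    lower = gamma_dist.ppf(tail, a=shape, scale=1.0 / rate)
    upper = gamma_dist.ppf(1.0 - tail, a=shape, scale=1.0 / rate)
    return float(np.mean((true_times >= lower) & (true_times <= upper)))

def log_ne_rmse(pred_ne, true_ne):
    """Log-scale RMSE between predicted and true N_e(t) on a shared time grid."""
    return float(np.sqrt(np.mean((np.log10(pred_ne) - np.log10(true_ne)) ** 2)))

A well-calibrated model gives gamma_coverage close to 0.9 at level=0.90; values well above indicate intervals that are too wide, values below indicate over-confidence.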
When to stop training
Training is complete when (1) the validation loss plateaus for 50k steps, (2) the time calibration is within 5% of the nominal level across all time scales, and (3) the \(N_e(t)\) RMSE on a held-out validation set of 1,000 simulations stops improving. In practice, Phase 3 accounts for most of the training time. Phase 4 (SLiM robustness) is optional and mainly relevant for applications where selection is expected.
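As a concrete sketch of these stopping criteria, assuming validation metrics are logged every 1,000 steps (so a 50k-step plateau corresponds to 50 checks) and that the calibration criterion is collapsed to a single overall coverage value:

def should_stop(val_losses, time_calibration, ne_rmses,
                window=50, calib_tol=0.05, level=0.90):
    """Apply the three stopping criteria to logged validation metrics.

    val_losses, ne_rmses: lists appended to at each validation check.
    time_calibration: most recent gamma_coverage value.
    """
    if len(val_losses) <= window:
        return False
    loss_plateaued = min(val_losses[-window:]) >= min(val_losses[:-window])
    calibrated = abs(time_calibration - level) <= calib_tol
    rmse_plateaued = min(ne_rmses[-window:]) >= min(ne_rmses[:-window])
    return loss_plateaued and calibrated and rmse_plateaued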