Architecture

The movement of a grande complication is built in stages. Each stage transforms the energy of the mainspring into a more refined form, until the final stage moves the hands with perfect precision.

Mainspring processes a genotype matrix through four stages, each producing an increasingly refined representation of the evolutionary history encoded in the data. The stages mirror the factorization of the full posterior:

\[p(\mathcal{A}, N_e \mid \mathbf{D}) = \underbrace{p(\mathcal{T} \mid \mathbf{D})}_{\text{topology}} \;\cdot\; \underbrace{p(\mathbf{t} \mid \mathcal{T}, \mathbf{D})}_{\text{dating}} \;\cdot\; \underbrace{p(N_e \mid \mathcal{T}, \mathbf{t})}_{\text{demography}}\]

where \(\mathcal{A} = (\mathcal{T}, \mathbf{t})\) is the ARG decomposed into topology \(\mathcal{T}\) and node times \(\mathbf{t}\), and the genomic encoder provides the shared representation from which all three factors are decoded.

┌─────────────────────────────────────────────────────────┐
│                                                         │
│  STAGE 1: GENOMIC ENCODER                               │
│  ┌─────────────┐   ┌─────────────┐   ┌──────────────┐  │
│  │  Embedding   │──▶│ Set Transf. │──▶│ Sliding-Win  │  │
│  │  (per site)  │   │ (samples)   │   │ Attn (pos.)  │  │
│  └─────────────┘   └─────────────┘   └──────────────┘  │
│           D ∈ {0,1}^{n×L}  ──▶  Z ∈ R^{n×L×d}         │
│                                                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  STAGE 2: TOPOLOGY DECODER                              │
│  ┌──────────────┐   ┌──────────────┐  ┌─────────────┐  │
│  │ Cross-attn   │──▶│ Breakpoint   │──▶│ Hard attn   │  │
│  │ (Li&Stephens)│   │ detector     │   │ (Gumbel-SM) │  │
│  └──────────────┘   └──────────────┘  └─────────────┘  │
│           Z  ──▶  T = {(parent[], left, right)}         │
│                                                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  STAGE 3: DATING GNN                                    │
│  ┌──────────────┐   ┌──────────────┐  ┌─────────────┐  │
│  │ Node/edge    │──▶│ UP/DOWN msg  │──▶│ Gamma heads │  │
│  │ features     │   │ passing (×K) │   │ (α_v, β_v)  │  │
│  └──────────────┘   └──────────────┘  └─────────────┘  │
│           T + muts  ──▶  t_v ~ Gamma(α_v, β_v)         │
│                                                         │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  STAGE 4: DEMOGRAPHIC DECODER                           │
│  ┌──────────────┐   ┌──────────────┐  ┌─────────────┐  │
│  │ Coalescence  │──▶│ Normalizing  │──▶│ SFS loss    │  │
│  │ time hist.   │   │ flow         │   │ (physics)   │  │
│  └──────────────┘   └──────────────┘  └─────────────┘  │
│           t_v  ──▶  q(N_e(t))                           │
│                                                         │
└─────────────────────────────────────────────────────────┘

Stage 1: Genomic Encoder

The encoder transforms the raw genotype matrix \(\mathbf{D} \in \{0,1\}^{n \times L}\) into a dense representation \(\mathbf{Z} \in \mathbb{R}^{n \times L \times d}\) that captures both inter-sample relationships and spatial correlations along the genome.

Embedding Layer

Each site is embedded independently. The input at site \(\ell\) is the column vector \(\mathbf{d}_\ell = (d_{1,\ell}, \ldots, d_{n,\ell})^\top \in \{0,1\}^n\). Each sample’s binary allele is embedded into \(\mathbb{R}^d\):

\[\mathbf{e}_{i,\ell} = \mathbf{W}_{\text{allele}}[d_{i,\ell}] + \text{RFF}(\ell) + \mathbf{W}_{\text{freq}} \cdot \hat{f}_\ell\]

where \(\mathbf{W}_{\text{allele}} \in \mathbb{R}^{2 \times d}\) is an allele embedding table, \(\text{RFF}(\ell)\) is a random Fourier feature positional encoding (see Design Principles – One Per Timepiece, Principle 8), and \(\hat{f}_\ell = \frac{1}{n}\sum_i d_{i,\ell}\) is the sample allele frequency at site \(\ell\), projected through \(\mathbf{W}_{\text{freq}} \in \mathbb{R}^d\).

import torch
import torch.nn as nn
import torch.nn.functional as F  # these imports are shared by the code snippets in this section

class GenomicEmbedding(nn.Module):
    def __init__(self, d_model, rff_sigma=10.0):
        super().__init__()
        self.allele_embed = nn.Embedding(2, d_model)                      # W_allele
        self.freq_proj = nn.Linear(1, d_model, bias=False)                # W_freq
        self.rff = RandomFourierPositionalEncoding(d_model, sigma=rff_sigma)

    def forward(self, D):
        B, n, L = D.shape
        allele = self.allele_embed(D)                          # (B, n, L, d)
        positions = torch.arange(L, device=D.device).float()
        pos_enc = self.rff(positions)                           # (L, d)
        freq = D.float().mean(dim=1, keepdim=True).unsqueeze(-1)  # (B, 1, L, 1)
        freq_enc = self.freq_proj(freq)                         # (B, 1, L, d)
        return allele + pos_enc.unsqueeze(0).unsqueeze(0) + freq_enc

Set Transformer over Samples

At each site, the \(n\) sample embeddings are processed by an induced set attention block (ISAB) that is permutation-equivariant over the sample dimension. This implements Principle 2 from Design Principles – One Per Timepiece.

The ISAB uses \(m\) inducing points to reduce the \(O(n^2)\) cost of full self-attention to \(O(nm)\):

\[\mathbf{H} = \text{ISAB}_m(\mathbf{E}_\ell) = \text{MAB}(\mathbf{E}_\ell,\; \text{MAB}(\mathbf{I}, \mathbf{E}_\ell))\]

where \(\text{MAB}(\mathbf{X}, \mathbf{Y}) = \text{LayerNorm}(\mathbf{X} + \text{MultiheadAttention}(\mathbf{X}, \mathbf{Y}, \mathbf{Y}))\) is a multihead attention block, \(\mathbf{I} \in \mathbb{R}^{m \times d}\) are learned inducing points, and \(\mathbf{E}_\ell \in \mathbb{R}^{n \times d}\) are the sample embeddings at site \(\ell\).
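
The InducedSetAttention module used by the SampleEncoder below is not spelled out in this section. A minimal sketch that follows the MAB/ISAB definitions above (the module name and constructor signature are chosen to match the SampleEncoder; everything else is an illustrative assumption):

class MAB(nn.Module):
    """Multihead attention block: LayerNorm(X + MHA(X, Y, Y))."""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, X, Y):
        attn_out, _ = self.attn(X, Y, Y)
        return self.norm(X + attn_out)

class InducedSetAttention(nn.Module):
    """ISAB_m(E) = MAB(E, MAB(I, E)) with m learned inducing points."""
    def __init__(self, d_model, n_heads, n_inducing):
        super().__init__()
        self.inducing = nn.Parameter(torch.randn(n_inducing, d_model))
        self.mab1 = MAB(d_model, n_heads)   # inducing points attend to the samples
        self.mab2 = MAB(d_model, n_heads)   # samples attend back to the m-point summary

    def forward(self, E):                    # E: (batch, n, d)
        I = self.inducing.unsqueeze(0).expand(E.size(0), -1, -1)
        H = self.mab1(I, E)                  # (batch, m, d)
        return self.mab2(E, H)               # (batch, n, d)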

class SampleEncoder(nn.Module):
    def __init__(self, d_model, n_heads, n_inducing, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            InducedSetAttention(d_model, n_heads, n_inducing)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return self.norm(x)

Sliding-Window Positional Attention

After the Set Transformer processes each site independently over samples, we apply sliding-window self-attention along the genomic axis (Principle 1). Each sample’s sequence of \(L\) site embeddings is treated as a sequence, and attention is restricted to a window of \(w\) positions:

class GenomicEncoder(nn.Module):
    def __init__(self, d_model, n_heads, n_layers,
                 n_inducing=32, window_size=512):
        super().__init__()
        self.embedding = GenomicEmbedding(d_model)
        self.sample_encoder = SampleEncoder(d_model, n_heads, n_inducing, 2)
        self.positional_layers = nn.ModuleList([
            SlidingWindowAttention(d_model, n_heads, window_size)
            for _ in range(n_layers)
        ])
        self.ffn_layers = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(d_model),
                nn.Linear(d_model, 4 * d_model),
                nn.GELU(),
                nn.Linear(4 * d_model, d_model),
            )
            for _ in range(n_layers)
        ])

    def forward(self, D):
        Z = self.embedding(D)                     # (B, n, L, d)
        B, n, L, d = Z.shape
        Z = Z.permute(0, 2, 1, 3).reshape(B * L, n, d)
        Z = self.sample_encoder(Z)                # Set Transformer over samples
        Z = Z.reshape(B, L, n, d).permute(0, 2, 1, 3)
        Z = Z.reshape(B * n, L, d)
        for attn, ffn in zip(self.positional_layers, self.ffn_layers):
            Z = Z + attn(Z)                       # sliding-window attention
            Z = Z + ffn(Z)
        Z = Z.reshape(B, n, L, d)
        return Z
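
The SlidingWindowAttention module referenced above is likewise left implicit. One simple realization, assumed here for illustration, masks standard self-attention to a band of half-width \(w/2\); a production version would use a blocked kernel rather than materializing the full \(L \times L\) mask:

class SlidingWindowAttention(nn.Module):
    """Self-attention along the genome, restricted to a local window of size w."""
    def __init__(self, d_model, n_heads, window_size):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.half_window = window_size // 2

    def forward(self, x):                        # x: (B * n, L, d)
        L = x.size(1)
        pos = torch.arange(L, device=x.device)
        # True entries are disallowed: positions farther apart than w/2
        mask = (pos[None, :] - pos[:, None]).abs() > self.half_window
        xn = self.norm(x)
        out, _ = self.attn(xn, xn, xn, attn_mask=mask)
        return out                               # residual is added by the caller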

Stage 2: Topology Decoder

The topology decoder converts the encoder’s latent representation into a sequence of local tree topologies with breakpoints. This is the most structurally novel component: it implements a learned Li & Stephens model (Principle 5).

Cross-Attention as Copying

At each genomic position \(\ell\), every sample \(i\) computes attention weights over all other samples. The attention weights represent the probability that sample \(i\) is “copying from” sample \(j\) at this position – the neural analogue of the Li & Stephens transition probabilities.

\[\mathbf{q}_i^\ell = \mathbf{W}_Q \mathbf{z}_{i,\ell}, \quad \mathbf{k}_j^\ell = \mathbf{W}_K \mathbf{z}_{j,\ell}, \quad \alpha_{ij}^\ell = \text{softmax}_j\!\left(\frac{\mathbf{q}_i^{\ell\top} \mathbf{k}_j^\ell}{\sqrt{d}}\right)\]

The attention matrix \(\mathbf{A}^\ell \in \mathbb{R}^{n \times n}\) at each position encodes the copying relationships. In a true Li & Stephens model, each row of this matrix would be a one-hot vector (each sample copies from exactly one source). We relax this to soft attention during training and gradually harden it.

Breakpoint Detection

Tree topology changes at recombination breakpoints. The breakpoint detector is a 1D convolution along the genomic axis that identifies positions where the attention pattern changes significantly:

\[b_\ell = \sigma\!\left(\mathbf{w}_b^\top \text{Conv1D}\!\bigl(\|\mathbf{A}^\ell - \mathbf{A}^{\ell-1}\|_F,\; \ldots\bigr) + c_b\right)\]

where \(b_\ell \in [0, 1]\) is the breakpoint probability at position \(\ell\) and \(\|\cdot\|_F\) is the Frobenius norm of the change in attention pattern. In the implementation sketch below, differences between adjacent encoder embeddings, pooled over samples, stand in for the attention-change signal, which avoids materializing the full \(n \times n\) attention matrices at every position.

class BreakpointDetector(nn.Module):
    def __init__(self, d_model, kernel_size=5):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, Z_diff):
        # Z_diff: (B, L-1, d) pooled embedding differences between adjacent sites
        x = self.proj(Z_diff).squeeze(-1).unsqueeze(1)   # (B, 1, L-1)
        return torch.sigmoid(self.conv(x)).squeeze(1)    # (B, L-1) breakpoint probs

Hard Attention via Gumbel-Softmax

To produce discrete tree topologies, we need hard parent assignments. During training, we use the Gumbel-softmax trick to maintain differentiability:

class TopologyDecoder(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # explicit W_Q / W_K projections for the copying attention defined above;
        # n_heads is kept for interface parity (the copying attention is single-head)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.n_heads = n_heads
        self.breakpoint_det = BreakpointDetector(d_model)
        self.tau = 1.0  # annealed during training

    def forward(self, Z, hard=False):
        B, n, L, d = Z.shape
        parent_logits, parent_assign = [], []
        for ell in range(L):
            q = self.q_proj(Z[:, :, ell, :])                        # (B, n, d)
            k = self.k_proj(Z[:, :, ell, :])                        # (B, n, d)
            scores = torch.bmm(q, k.transpose(1, 2)) / (d ** 0.5)   # (B, n, n)
            # a sample never copies from itself
            eye = torch.eye(n, dtype=torch.bool, device=Z.device)
            scores = scores.masked_fill(eye, float('-inf'))

            if hard:
                parents = F.one_hot(scores.argmax(dim=-1), n).float()
            else:
                # straight-through Gumbel-softmax: hard sample, soft gradient
                parents = F.gumbel_softmax(scores, tau=self.tau, hard=True, dim=-1)
            parent_logits.append(scores)
            parent_assign.append(parents)

        parent_logits = torch.stack(parent_logits, dim=2)    # (B, n, L, n)
        parent_assign = torch.stack(parent_assign, dim=2)    # (B, n, L, n) one-hot

        Z_diff = Z[:, :, 1:, :] - Z[:, :, :-1, :]
        Z_diff_pooled = Z_diff.mean(dim=1)                    # pool over samples
        breakpoints = self.breakpoint_det(Z_diff_pooled)      # (B, L-1)

        return parent_logits, parent_assign, breakpoints

At inference time (\(\tau \to 0\) or hard=True), the Gumbel-softmax collapses to argmax, producing deterministic parent assignments. The topology is then assembled into contiguous tree segments separated by breakpoints.
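
The annealing schedule for \(\tau\) is not fixed by the architecture; a common choice, shown here only as an assumption, decays it geometrically toward a floor over training:

def anneal_tau(decoder, epoch, tau_start=1.0, tau_min=0.1, decay=0.95):
    """Geometric Gumbel-softmax temperature decay; call once per epoch."""
    decoder.tau = max(tau_min, tau_start * decay ** epoch)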

From attention to tree topology

The attention matrix \(\mathbf{A}^\ell\) does not directly encode a valid tree. To obtain a tree, we apply a greedy bottom-up construction: starting from the leaves, we iteratively merge the pair with the highest mutual attention weight, creating an internal node. The process continues until all samples are connected. This is reminiscent of hierarchical clustering, but the similarity metric is learned end-to-end.
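
A sketch of this greedy construction for a single genomic segment, assuming a symmetrized attention matrix and average-linkage updates for merged clusters (both assumptions; the reference implementation may differ):

import numpy as np

def greedy_tree_from_attention(A):
    """Merge the pair with the highest mutual attention until one root remains.

    A: (n, n) symmetrized attention matrix, e.g. 0.5 * (attn + attn.T).
    Returns a parent array of length 2n - 1; the root keeps parent -1.
    """
    n = A.shape[0]
    parent = -np.ones(2 * n - 1, dtype=int)
    sim = {(i, j): A[i, j] for i in range(n) for j in range(i + 1, n)}
    active = list(range(n))
    next_node = n
    while len(active) > 1:
        i, j = max(((a, b) for a in active for b in active if a < b),
                   key=lambda p: sim[p])
        parent[i] = parent[j] = next_node
        # similarity of the new cluster to the others: average linkage (assumption)
        for k in active:
            if k in (i, j):
                continue
            s = 0.5 * (sim[(min(i, k), max(i, k))] + sim[(min(j, k), max(j, k))])
            sim[(k, next_node)] = s
        active = [k for k in active if k not in (i, j)] + [next_node]
        next_node += 1
    return parent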

Stage 3: Dating GNN

Given the predicted topology, the dating GNN assigns times to internal nodes using learned message passing. This is the neural analogue of tsdate’s inside-outside algorithm (Principle 4), with gamma output heads (Principle 7) and per-segment sufficient statistics as input features (Principle 9).

Node and Edge Features

Each node \(v\) in a local tree is initialized with a feature vector:

\[\begin{split}\mathbf{h}_v^{(0)} = \begin{cases} \mathbf{W}_{\text{leaf}} \mathbf{z}_{v,\ell} & \text{if } v \text{ is a leaf} \\ \mathbf{W}_{\text{init}} [\hat{t}_v;\; \log \hat{t}_v;\; \mathbf{0}] & \text{if } v \text{ is internal} \end{cases}\end{split}\]

where \(\mathbf{z}_{v,\ell}\) is the encoder output for leaf \(v\) at the midpoint of the tree’s genomic span, and \(\hat{t}_v\) is an initial time estimate from the Threads-style natural estimator.

Each edge \((u, v)\) carries features:

\[\mathbf{f}_{uv} = \mathbf{W}_{\text{edge}} [m_{uv};\; s_{uv};\; \hat{t}_{uv};\; \log m_{uv};\; \log s_{uv};\; n_{uv}]\]

where \(m_{uv}\) is the mutation count, \(s_{uv}\) the genomic span, \(\hat{t}_{uv}\) the natural time estimate, and \(n_{uv}\) the number of descendant leaves.
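
For concreteness, the raw six-dimensional input to the edge encoder (a Linear(6, d_model) in the DatingGNN below) could be assembled per edge as follows; the tensor names are illustrative, the inputs are assumed to be float tensors with one entry per edge, and the ordering follows the equation above:

def edge_raw_features(m_uv, s_uv, t_hat_uv, n_uv, eps=1e-8):
    """[m, s, t_hat, log m, log s, n] for each edge, as an (E, 6) tensor."""
    return torch.stack([
        m_uv,                    # mutation count on the edge
        s_uv,                    # genomic span covered by the edge
        t_hat_uv,                # Threads-style natural time estimate
        torch.log(m_uv + eps),   # log-transformed counts tame heavy tails
        torch.log(s_uv + eps),
        n_uv,                    # number of descendant leaves
    ], dim=-1)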

UP/DOWN Message Passing

The GNN alternates between UP passes (children to parent, analogous to tsdate’s inside pass) and DOWN passes (parent to children, analogous to the outside pass):

\[\mathbf{m}_{c \to p}^{(k)} = \text{MLP}_{\text{up}}\!\bigl( [\mathbf{h}_c^{(k)};\; \mathbf{f}_{cp}]\bigr) \qquad \text{(UP message)}\]
\[\mathbf{m}_{p \to c}^{(k)} = \text{MLP}_{\text{down}}\!\bigl( [\mathbf{h}_p^{(k)};\; \mathbf{f}_{pc}]\bigr) \qquad \text{(DOWN message)}\]
\[\mathbf{h}_v^{(k+1)} = \text{GRU}\!\left(\mathbf{h}_v^{(k)},\; \sum_{u \in \text{children}(v)} \mathbf{m}_{u \to v}^{(k)} + \mathbf{m}_{\text{parent}(v) \to v}^{(k)}\right)\]

The GRU (gated recurrent unit) update prevents the node features from drifting too far from their initial values while allowing iterative refinement. After \(K\) rounds (typically \(K = 6\)), the node features are decoded into gamma parameters.

class DatingGNN(nn.Module):
    def __init__(self, d_model, n_rounds=6):
        super().__init__()
        self.n_rounds = n_rounds
        self.node_init = nn.Linear(d_model, d_model)
        self.edge_encoder = nn.Linear(6, d_model)
        self.up_msg = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU(),
                                    nn.Linear(d_model, d_model))
        self.down_msg = nn.Sequential(nn.Linear(2 * d_model, d_model), nn.ReLU(),
                                      nn.Linear(d_model, d_model))
        self.gru = nn.GRUCell(d_model, d_model)
        self.alpha_head = nn.Linear(d_model, 1)
        self.beta_head = nn.Linear(d_model, 1)

    def forward(self, node_features, edge_features, parent_array):
        # node_features: one row per node; edge_features: the six raw statistics
        # for the edge above each node; parent_array[v] = parent of v (-1 at the root)
        h = self.node_init(node_features)
        f = self.edge_encoder(edge_features)

        for k in range(self.n_rounds):
            # each round fuses the UP (child -> parent) and DOWN (parent -> child)
            # messages, which are summed before the shared GRU update
            msg = torch.zeros_like(h)
            for child, parent in enumerate(parent_array):
                if parent < 0:
                    continue
                up = self.up_msg(torch.cat([h[child], f[child]], dim=-1))
                msg[parent] += up
                down = self.down_msg(torch.cat([h[parent], f[child]], dim=-1))
                msg[child] += down
            h = self.gru(msg, h)

        alpha = F.softplus(self.alpha_head(h)) + 1.0
        beta = torch.exp(self.beta_head(h))
        return alpha, beta

Gamma Output Heads

The final node features \(\mathbf{h}_v^{(K)}\) are decoded into gamma parameters \((\alpha_v, \beta_v)\):

\[\alpha_v = \text{softplus}(\mathbf{w}_\alpha^\top \mathbf{h}_v^{(K)}) + 1, \qquad \beta_v = \exp(\mathbf{w}_\beta^\top \mathbf{h}_v^{(K)})\]

The predicted time distribution for node \(v\) is then \(t_v \sim \text{Gamma}(\alpha_v, \beta_v)\), with mean \(\mathbb{E}[t_v] = \alpha_v / \beta_v\) and variance \(\text{Var}(t_v) = \alpha_v / \beta_v^2\).
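
From \((\alpha_v, \beta_v)\) one can read off point estimates and credible intervals directly. A small sketch using SciPy's gamma quantile function (note the conversion from the shape/rate convention used here to SciPy's shape/scale):

from scipy.stats import gamma as gamma_dist

def node_time_summary(alpha, beta, level=0.95):
    """Posterior mean and central credible interval for t_v ~ Gamma(alpha, beta)."""
    a = alpha.detach().cpu().numpy().squeeze(-1)
    rate = beta.detach().cpu().numpy().squeeze(-1)
    mean = a / rate
    lo = gamma_dist.ppf((1 - level) / 2, a=a, scale=1.0 / rate)
    hi = gamma_dist.ppf((1 + level) / 2, a=a, scale=1.0 / rate)
    return mean, lo, hi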

Cross-Tree Consistency

Adjacent local trees share most of their topology and node times. To enforce consistency, we add a cross-tree regularizer that penalizes large changes in predicted node times between adjacent trees:

\[\mathcal{L}_{\text{consistency}} = \sum_{\ell=1}^{T-1} \sum_{v \in \mathcal{V}_\ell \cap \mathcal{V}_{\ell+1}} \left(\log \frac{\alpha_v^\ell}{\beta_v^\ell} - \log \frac{\alpha_v^{\ell+1}}{\beta_v^{\ell+1}}\right)^2\]

where \(\mathcal{V}_\ell\) is the set of nodes in local tree \(\ell\) and the intersection identifies nodes shared between adjacent trees.
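
A direct sketch of this regularizer, assuming the per-tree gamma parameters are available as lists of tensors and that, for each adjacent pair of trees, index tensors select the shared nodes in each (names are illustrative):

def cross_tree_consistency_loss(alphas, betas, shared):
    """Penalize log-mean-time changes for nodes shared by adjacent local trees.

    alphas, betas: lists of per-tree tensors, one entry per local tree.
    shared: list with, for each adjacent pair of trees, a tuple (idx_l, idx_r)
            of index tensors selecting the shared nodes in each tree.
    """
    loss = 0.0
    for ell, (idx_l, idx_r) in enumerate(shared):
        log_mean_l = torch.log(alphas[ell][idx_l] / betas[ell][idx_l])
        log_mean_r = torch.log(alphas[ell + 1][idx_r] / betas[ell + 1][idx_r])
        loss = loss + ((log_mean_l - log_mean_r) ** 2).sum()
    return loss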

Stage 4: Demographic Decoder

The final stage maps the inferred coalescence-time distribution to a posterior over \(N_e(t)\) trajectories. This is where the ARG’s status as a sufficient statistic (Principle 3) pays off: the demographic decoder operates entirely on the predicted coalescence times, not on the raw genotype matrix.

Coalescence-Time Histogram

From the dated ARG, we extract a histogram of coalescence times. For each internal node \(v\) at time \(t_v\) with \(k_v\) children, we record \(k_v - 1\) coalescence events at time \(t_v\). Binning these into \(B\) logarithmically-spaced time intervals gives a vector \(\mathbf{c} \in \mathbb{R}^B\):

\[c_b = \sum_{v \in \text{internal nodes}} (k_v - 1) \cdot \mathbf{1}[t_v \in \text{bin } b]\]

This histogram, together with the predicted SFS from the ARG, forms the input to the normalizing flow.
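
A sketch of the binning step, assuming node times and child counts as 1-D tensors and a fixed log-spaced grid (the grid endpoints here are placeholders):

import math

def coalescence_histogram(times, n_children, t_min=1.0, t_max=1e5, n_bins=64):
    """Histogram of coalescence events: node v contributes k_v - 1 events at t_v."""
    edges = torch.logspace(math.log10(t_min), math.log10(t_max), n_bins + 1,
                           device=times.device)
    bins = torch.bucketize(times.clamp(t_min, t_max), edges) - 1
    bins = bins.clamp(0, n_bins - 1)
    hist = torch.zeros(n_bins, device=times.device)
    hist.index_add_(0, bins, (n_children - 1).float())
    return hist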

Normalizing Flow

The demographic decoder is a conditional normalizing flow that transforms a simple base distribution (standard normal) into a posterior over \(N_e(t)\) functions, conditioned on the coalescence-time histogram and SFS:

\[\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \quad N_e(t) = g_\phi(\mathbf{z} \mid \mathbf{c}, \text{SFS})\]

where \(g_\phi\) is an invertible neural network parameterized by \(\phi\). The \(N_e(t)\) trajectory is represented as a vector of \(B\) values on the same log-spaced time grid, with linear interpolation between grid points.

class DemographicDecoder(nn.Module):
    def __init__(self, n_time_bins, n_flow_layers, d_cond):
        super().__init__()
        self.condition_net = nn.Sequential(
            nn.Linear(2 * n_time_bins, d_cond),
            nn.ReLU(),
            nn.Linear(d_cond, d_cond),
        )
        self.flow_layers = nn.ModuleList([
            AffineCouplingLayer(n_time_bins, d_cond)
            for _ in range(n_flow_layers)
        ])

    def forward(self, coal_histogram, sfs, n_samples=1):
        # assumes the SFS has been binned to n_time_bins entries, matching the
        # 2 * n_time_bins input width of condition_net
        cond = self.condition_net(torch.cat([coal_histogram, sfs], dim=-1))
        z = torch.randn(n_samples, coal_histogram.size(-1),
                        device=coal_histogram.device)
        log_det = 0.0
        for layer in self.flow_layers:
            z, ld = layer(z, cond)
            log_det = log_det + ld
        ne_trajectory = F.softplus(z)    # N_e(t) must be positive
        return ne_trajectory, log_det
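
The AffineCouplingLayer is not defined in this section. A minimal RealNVP-style conditional coupling layer matching the interface used above (takes z and the conditioning vector, returns the transformed z and a log-determinant) might look like the following; in practice adjacent layers would alternate which half is transformed:

class AffineCouplingLayer(nn.Module):
    """Conditional affine coupling: transform one half of z given the other half."""
    def __init__(self, dim, d_cond, hidden=128):
        super().__init__()
        self.d1 = dim // 2                       # pass-through half
        self.d2 = dim - self.d1                  # transformed half
        self.net = nn.Sequential(
            nn.Linear(self.d1 + d_cond, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * self.d2),
        )

    def forward(self, z, cond):
        z1, z2 = z[..., :self.d1], z[..., self.d1:]
        inp = torch.cat([z1, cond.expand(z1.size(0), -1)], dim=-1)
        log_scale, shift = self.net(inp).chunk(2, dim=-1)
        log_scale = torch.tanh(log_scale)        # bound the scales for stability
        z2 = z2 * torch.exp(log_scale) + shift
        log_det = log_scale.sum(dim=-1)          # per-sample log |det J|
        return torch.cat([z1, z2], dim=-1), log_det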

SFS Auxiliary Loss

The predicted ARG implies a predicted SFS, which must be consistent with the observed SFS. This consistency check is the physics-informed regularizer from Principle 6:

\[\mathcal{L}_{\text{SFS}} = \sum_{k=1}^{n-1} \left( \text{SFS}_{\text{pred}}[k] \cdot \mu - \text{SFS}_{\text{obs}}[k] \right)^2 / \text{SFS}_{\text{obs}}[k]\]

where the predicted SFS is computed differentiably from the ARG branch lengths and descendant counts, and the observed SFS is computed directly from the genotype matrix.
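
A direct sketch of this loss, assuming the predicted SFS entry for frequency class \(k\) is the total branch length subtending exactly \(k\) samples (so multiplying by \(\mu\) gives an expected mutation count) and the observed SFS holds raw counts:

def sfs_loss(sfs_pred_branch_len, sfs_obs_counts, mu, eps=1e-8):
    """Chi-square-style discrepancy between expected and observed SFS counts."""
    expected = sfs_pred_branch_len * mu
    return ((expected - sfs_obs_counts) ** 2 / (sfs_obs_counts + eps)).sum()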

Why the SFS loss matters

Without the SFS loss, the network can produce ARGs that correctly reconstruct the topology and approximate the node times but systematically miscount the number of lineages at each frequency class. The SFS loss acts as a global consistency check: it catches errors in the predicted ARG that local losses (topology accuracy, node time likelihood) might miss. This is analogous to how a watchmaker, after assembling each gear individually, checks that the overall gear train produces the correct time – a global test that catches assembly errors invisible at the component level.

Putting It All Together

The complete Mainspring model chains all four stages:

class Mainspring(nn.Module):
    def __init__(self, d_model=256, n_heads=8, n_encoder_layers=6,
                 n_gnn_rounds=6, n_time_bins=64, n_flow_layers=8):
        super().__init__()
        self.encoder = GenomicEncoder(d_model, n_heads, n_encoder_layers)
        self.topology_decoder = TopologyDecoder(d_model, n_heads)
        self.dating_gnn = DatingGNN(d_model, n_gnn_rounds)
        self.demographic_decoder = DemographicDecoder(
            n_time_bins, n_flow_layers, d_cond=128
        )

    def forward(self, D, hard=False):
        Z = self.encoder(D)
        parent_logits, parent_assign, breakpoints = self.topology_decoder(Z, hard=hard)
        topology = self.extract_trees(parent_assign, breakpoints)
        node_feats, edge_feats, parent_arrays = self.build_gnn_input(
            Z, topology
        )
        alphas, betas = self.dating_gnn(node_feats, edge_feats, parent_arrays)
        times = alphas / betas  # point estimate = gamma mean
        coal_hist = self.build_coalescence_histogram(times, topology)
        pred_sfs = self.compute_sfs(times, topology)
        ne_posterior, log_det = self.demographic_decoder(coal_hist, pred_sfs)
        return {
            'topology': topology,
            'breakpoints': breakpoints,
            'alpha': alphas,
            'beta': betas,
            'times': times,
            'ne_posterior': ne_posterior,
            'flow_log_det': log_det,
            'predicted_sfs': pred_sfs,
        }
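
A minimal usage sketch, with a random genotype matrix and illustrative shapes (the helper methods extract_trees, build_gnn_input, build_coalescence_histogram, and compute_sfs are assumed to be implemented as described above):

model = Mainspring(d_model=256, n_heads=8)
D = torch.randint(0, 2, (1, 50, 2000))    # one batch: 50 haplotypes, 2,000 sites
out = model(D, hard=True)                 # inference: deterministic topology
print(out['times'].shape, out['ne_posterior'].shape)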

Computational Complexity

Per-stage computational complexity

  Stage                  Complexity                   Bottleneck
  Genomic Encoder        \(O(n m L d + n L w d)\)     ISAB over samples (\(m\) inducing points per site) + sliding-window attention (window \(w\) per position)
  Topology Decoder       \(O(n^2 L d)\)               Cross-attention over all sample pairs at each site
  Dating GNN             \(O(K n L d)\)               \(K\) message-passing rounds on trees with \(O(n)\) nodes
  Demographic Decoder    \(O(B^2)\)                   Normalizing flow on \(B\) time bins (negligible)

Total: \(O(n^2 L d + n L w d)\), linear in sequence length, quadratic in sample count, and linear in the attention window. For typical applications (\(n \leq 100\), \(L \sim 10^4\)), this is feasible on a single GPU.