.. _escapement_architecture:

==============
Architecture
==============

   *Four wheels, precisely meshed. The escape wheel receives energy from the
   gear train. The pallet fork converts rotational motion into oscillation.
   The balance spring stores and returns energy. The regulator adjusts the
   rate. Remove any one, and the mechanism stops.*

Escapement has four modules, each with a distinct role in the variational
inference pipeline. Unlike :ref:`Mainspring <mainspring_architecture>`, which is
trained on simulated (input, target) pairs, every module here is trained
end-to-end by maximizing the ELBO on observed data. There are no ground-truth
labels. The coalescent likelihood *is* the supervision signal.

.. code-block:: text

   Module 1                Module 2                Module 3              Module 4
   GENEALOGY               VARIATIONAL TREE        DIFFERENTIABLE        DEMOGRAPHIC
   ENCODER                 POSTERIOR               LIKELIHOOD            INFERENCE
   (escape wheel)          (pallet fork)           (balance spring)      (regulator)

   D ∈ {0,1}^{n×L}        h ∈ R^{n×L×d}         τ ~ q(τ|D,φ)          ELBO terms
        |                       |                       |                    |
        v                       v                       v                    v
   ┌───────────┐         ┌──────────────┐        ┌──────────────┐     ┌──────────┐
   │Transformer│ ──────▶ │ Topology:    │──────▶ │ log P(D|τ,μ) │     │ N_e(t):  │
   │ sample ×  │         │  Gumbel-SM   │  τ     │ log P(τ|Ne,ρ)│◀────│ piecewise│
   │ position  │         │ Branches:    │        │ H[q]         │     │ or spline│
   │           │         │  LogNormal   │        │              │     │          │
   │           │         │ Breakpoints: │        │ = ELBO       │     │          │
   │           │         │  Bernoulli   │        │ (pure math)  │     │          │
   └───────────┘         └──────────────┘        └──────────────┘     └──────────┘


Module 1: Genealogy Encoder
==============================

The encoder transforms the raw genotype matrix
:math:`\mathbf{D} \in \{0,1\}^{n \times L}` into latent vectors
:math:`\mathbf{h} \in \mathbb{R}^{n \times L \times d}`. It must capture
two kinds of structure:

- **Inter-sample relationships** at each genomic position (which samples share
  recent common ancestry here?)
- **Spatial correlations** along the genome (how does ancestry change across
  positions due to recombination?)

The architecture is a Transformer that alternates between attention over samples
and attention over positions -- the same dual-axis design as
:ref:`Mainspring <mainspring_architecture>`, but optimized against the ELBO
rather than simulation-matching losses.

Sample Attention
------------------

At each genomic position, the :math:`n` sample embeddings are processed by
multi-head self-attention. This is permutation-equivariant over samples: if the
sample order in :math:`\mathbf{D}` is permuted, the output embeddings are
permuted identically. This encodes the exchangeability of coalescent samples.

.. math::

   \text{SampleAttn}(\mathbf{E}_\ell) = \text{LayerNorm}\!\left(
   \mathbf{E}_\ell + \text{MHA}(\mathbf{E}_\ell, \mathbf{E}_\ell,
   \mathbf{E}_\ell)\right)

where :math:`\mathbf{E}_\ell \in \mathbb{R}^{n \times d}` is the embedding
matrix at position :math:`\ell` and MHA is multi-head attention.

The attention weights :math:`\alpha_{ij}^\ell` at each position have a direct
interpretation: they measure how much sample :math:`i` "looks at" sample
:math:`j`. In a well-trained model, high attention corresponds to recent common
ancestry -- samples that coalesce early in the tree attend strongly to each
other. This is the neural analogue of the Li & Stephens copying probabilities
from :ref:`tsinfer <tsinfer_timepiece>`.

.. code-block:: python

   import math

   import torch
   import torch.nn as nn
   import torch.nn.functional as F


   class SampleAttention(nn.Module):
       def __init__(self, d_model, n_heads=4, dropout=0.1):
           super().__init__()
           self.n_heads = n_heads
           self.head_dim = d_model // n_heads
           self.qkv = nn.Linear(d_model, 3 * d_model)
           self.out_proj = nn.Linear(d_model, d_model)
           self.dropout = nn.Dropout(dropout)
           self.norm = nn.LayerNorm(d_model)

       def forward(self, x):
           # x: (batch, n_samples, d_model) -- one genomic position per batch item
           residual = x
           B, N, D = x.shape
           qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
           qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
           q, k, v = qkv.unbind(0)
           attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
           attn = F.softmax(attn, dim=-1)
           attn = self.dropout(attn)
           out = (attn @ v).transpose(1, 2).reshape(B, N, D)
           return self.norm(residual + self.out_proj(out))
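
The equivariance claim is easy to check empirically. A minimal sanity test
(illustrative), run in eval mode so that dropout is disabled:

.. code-block:: python

   attn = SampleAttention(d_model=64).eval()
   x = torch.randn(1, 8, 64)        # 8 samples at one genomic position
   perm = torch.randperm(8)
   # Permuting the input samples permutes the outputs identically.
   assert torch.allclose(attn(x[:, perm]), attn(x)[:, perm], atol=1e-5)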

Sliding-Window Positional Attention
--------------------------------------

After modeling inter-sample relationships at each position, the encoder
captures spatial correlations along the genome using sliding-window
self-attention. For each sample, the :math:`L` positional embeddings form a
sequence, and attention is restricted to a window of :math:`w` positions.

The window size :math:`w` should be approximately :math:`1/\rho` (the expected
distance between recombination events), measured in number of segregating sites.
Within this window, the local tree is approximately constant, and the attention
mechanism can detect patterns of linkage disequilibrium that reveal the local
genealogy.

.. math::

   \text{PosAttn}(\mathbf{h}_i) = \text{LayerNorm}\!\left(
   \mathbf{h}_i + \text{MHA}_w(\mathbf{h}_i, \mathbf{h}_i,
   \mathbf{h}_i)\right)

where :math:`\mathbf{h}_i \in \mathbb{R}^{L \times d}` is the embedding
sequence for sample :math:`i`, and :math:`\text{MHA}_w` restricts attention
to positions within distance :math:`w`.

.. code-block:: python

   class PositionalAttention(nn.Module):
       def __init__(self, d_model, n_heads=4, window=64, dropout=0.1):
           super().__init__()
           self.window = window
           self.n_heads = n_heads
           self.head_dim = d_model // n_heads
           self.qkv = nn.Linear(d_model, 3 * d_model)
           self.out_proj = nn.Linear(d_model, d_model)
           self.dropout = nn.Dropout(dropout)
           self.norm = nn.LayerNorm(d_model)

       def forward(self, x):
           # x: (batch, L positions, d_model) -- one sample per batch item
           residual = x
           B, L, D = x.shape
           W = min(self.window, L)
           qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim)
           qkv = qkv.permute(2, 0, 3, 1, 4)
           q, k, v = qkv.unbind(0)
           attn_logits = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
           # Banded mask: positions farther apart than W cannot attend.
           positions = torch.arange(L, device=x.device)
           mask = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs() > W
           attn_logits = attn_logits.masked_fill(
               mask.unsqueeze(0).unsqueeze(0), float("-inf")
           )
           attn = F.softmax(attn_logits, dim=-1)
           attn = self.dropout(attn)
           out = (attn @ v).transpose(1, 2).reshape(B, L, D)
           return self.norm(residual + self.out_proj(out))

Full Encoder
--------------

The complete encoder stacks :math:`K` blocks, each containing one sample
attention layer, one positional attention layer, and a feedforward network:

.. code-block:: python

   class GenealogyEncoder(nn.Module):
       def __init__(self, d_model=64, n_heads=4, n_layers=2,
                    window=64, dropout=0.1):
           super().__init__()
           self.allele_embed = nn.Linear(1, d_model)
           self.layers = nn.ModuleList()
           for _ in range(n_layers):
               self.layers.append(nn.ModuleDict({
                   "sample_attn": SampleAttention(d_model, n_heads, dropout),
                   "pos_attn": PositionalAttention(d_model, n_heads, window, dropout),
                   "ffn": nn.Sequential(
                       nn.Linear(d_model, 4 * d_model), nn.GELU(),
                       nn.Linear(4 * d_model, d_model), nn.Dropout(dropout)),
                   "ffn_norm": nn.LayerNorm(d_model),
               }))

       def forward(self, genotypes):
           # genotypes: (batch, n_samples, L positions), entries in {0, 1}
           B, N, L = genotypes.shape
           x = self.allele_embed(genotypes.float().unsqueeze(-1))
           D = x.shape[-1]
           for layer in self.layers:
               # Attend over samples: fold positions into the batch axis.
               x_pos = x.permute(0, 2, 1, 3).reshape(B * L, N, D)
               x_pos = layer["sample_attn"](x_pos)
               x = x_pos.reshape(B, L, N, D).permute(0, 2, 1, 3)
               # Attend over positions: fold samples into the batch axis.
               x_samp = x.reshape(B * N, L, D)
               x_samp = layer["pos_attn"](x_samp)
               x = x_samp.reshape(B, N, L, D)
               residual = x
               x = layer["ffn_norm"](residual + layer["ffn"](x))
           return x

.. admonition:: Shared encoder, different objective

   Escapement's encoder architecture is nearly identical to Mainspring's
   genomic encoder. The crucial difference is the training objective. Mainspring
   trains the encoder to predict the true ARG from simulations. Escapement
   trains it to produce latent vectors from which the variational posterior can
   generate genealogies that maximize the ELBO. The same architecture, trained
   with different losses, learns different representations.


Module 2: Variational Tree Posterior
======================================

The variational posterior :math:`q(\tau \mid \mathbf{D}, \phi)` maps the
encoder's latent vectors to a distribution over tree sequences. A tree sequence
has three components -- topology, branch lengths, and breakpoints -- and the
posterior factorizes accordingly:

.. math::

   q(\tau \mid \mathbf{D}, \phi) = q_{\text{topo}}(\pi \mid \mathbf{h})
   \cdot q_{\text{branch}}(t \mid \mathbf{h})
   \cdot q_{\text{break}}(b \mid \mathbf{h})

This mean-field factorization is an approximation: in reality, topology and
branch lengths are correlated (e.g., star-like trees imply recent coalescence).
The approximation enables tractable entropy computation and efficient sampling.

Topology: Gumbel-Softmax Parent Assignments
----------------------------------------------

At each genomic position :math:`\ell`, each sample :math:`i` chooses a parent
:math:`j \neq i` from the other samples. The parent assignment probabilities
are computed by scaled dot-product attention:

.. math::

   \alpha_{ij}^\ell = \frac{\exp(\mathbf{q}_i^{\ell\top} \mathbf{k}_j^\ell / \sqrt{d})}
   {\sum_{k \neq i} \exp(\mathbf{q}_i^{\ell\top} \mathbf{k}_k^\ell / \sqrt{d})}

where :math:`\mathbf{q}_i^\ell = \mathbf{W}_Q \mathbf{h}_{i,\ell}` and
:math:`\mathbf{k}_j^\ell = \mathbf{W}_K \mathbf{h}_{j,\ell}`. Self-assignment
is masked out (:math:`\alpha_{ii}^\ell = 0`).

During training, the parent assignment is sampled via Gumbel-softmax with
temperature :math:`T` (not to be confused with the genealogy :math:`\tau`):

.. code-block:: python

   class TopologyHead(nn.Module):
       def __init__(self, d_model):
           super().__init__()
           self.query_proj = nn.Linear(d_model, d_model)
           self.key_proj = nn.Linear(d_model, d_model)

       def forward(self, h, temperature=1.0, hard=False):
           # h: (batch, n_samples, d_model) at one genomic position
           Q = self.query_proj(h)
           K = self.key_proj(h)
           N = h.shape[1]
           logits = (Q @ K.transpose(-2, -1)) / (h.shape[-1] ** 0.5)
           # Mask self-assignment: a sample cannot be its own parent.
           mask = torch.eye(N, device=h.device, dtype=torch.bool).unsqueeze(0)
           logits = logits.masked_fill(mask, float("-inf"))
           parent_probs = F.gumbel_softmax(logits, tau=temperature, hard=hard)
           # Zero the masked diagonal so 0 * (-inf) does not produce NaN.
           log_probs = F.log_softmax(logits, dim=-1).masked_fill(mask, 0.0)
           chosen_log_probs = (parent_probs * log_probs).sum(dim=-1)
           return parent_probs, chosen_log_probs

The topology entropy is the categorical entropy of the parent assignment
distribution:

.. math::

   H_{\text{topo}}[q] = -\sum_{\ell=1}^{L} \sum_{i=1}^{n} \sum_{j \neq i}
   \alpha_{ij}^\ell \log \alpha_{ij}^\ell

Branch Lengths: Log-Normal Reparameterization
-------------------------------------------------

Coalescence times are positive and typically span several orders of magnitude
(from tens to millions of generations). The log-normal distribution is a natural
choice:

.. math::

   t_{i,\ell} \sim \text{LogNormal}(\mu_{i,\ell}, \sigma_{i,\ell})

where :math:`\mu_{i,\ell}` and :math:`\sigma_{i,\ell}` are predicted by an
MLP from the latent vectors. The reparameterization trick enables gradient
flow:

.. math::

   t_{i,\ell} = \exp(\mu_{i,\ell} + \sigma_{i,\ell} \cdot \epsilon), \qquad
   \epsilon \sim \mathcal{N}(0, 1)

.. code-block:: python

   class BranchLengthHead(nn.Module):
       def __init__(self, d_model, expected_tmrca=20000.0):
           super().__init__()
           self.log_expected = math.log(max(expected_tmrca, 1.0))
           self.mlp = nn.Sequential(
               nn.Linear(d_model, d_model), nn.GELU(),
               nn.Linear(d_model, 2))
           # Start predictions near the expected TMRCA (see note below).
           with torch.no_grad():
               self.mlp[-1].bias[0] = self.log_expected
               self.mlp[-1].bias[1] = -1.0

       def forward(self, h):
           raw = self.mlp(h)
           log_mean = raw[..., 0]
           # Softplus keeps the log-normal scale parameter sigma positive.
           sigma = F.softplus(raw[..., 1]) + 1e-4
           return log_mean, sigma

The output bias is initialized so that branch-length predictions start out
centered at :math:`\log(2 N_e)` -- the expected TMRCA for a pair of lineages
under the coalescent. This keeps the optimization from starting in a regime
where all branches are unreasonably short or long.
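
A reparameterized draw is then a deterministic function of the predicted
parameters and standard-normal noise. A minimal sketch, where ``branch_head``
is an (illustrative) instance of the head above applied to the encoder's
latent vectors ``h``:

.. code-block:: python

   branch_head = BranchLengthHead(d_model=64)
   log_mean, sigma = branch_head(h)                  # log-normal parameters
   eps = torch.randn_like(log_mean)                  # eps ~ N(0, 1)
   branch_times = torch.exp(log_mean + sigma * eps)  # gradients reach phi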

The branch-length entropy is the log-normal entropy, which has a closed-form
expression:

.. math::

   H_{\text{branch}}[q] = \sum_{i,\ell} \left[\mu_{i,\ell} + \frac{1}{2}
   + \log \sigma_{i,\ell} + \frac{1}{2}\log(2\pi)\right]

Breakpoints: Bernoulli Probabilities
----------------------------------------

At each position :math:`\ell`, a recombination breakpoint occurs with
probability :math:`b_\ell \in [0, 1]`. The breakpoint detector compares
adjacent latent vectors and predicts whether the local tree changes:

.. code-block:: python

   class BreakpointHead(nn.Module):
       def __init__(self, d_model):
           super().__init__()
           self.mlp = nn.Sequential(
               nn.Linear(2 * d_model, d_model), nn.GELU(),
               nn.Linear(d_model, 1))

       def forward(self, h):
           # h: (batch, n_samples, L positions, d_model)
           h_left = h[:, :, :-1, :]
           h_right = h[:, :, 1:, :]
           # Compare adjacent positions: a change suggests a new local tree.
           pair = torch.cat([h_left, h_right], dim=-1)
           logits = self.mlp(pair).squeeze(-1)
           # Pool evidence across samples, then map to a probability.
           probs = torch.sigmoid(logits.mean(dim=1))
           return probs

The breakpoint entropy is the Bernoulli entropy:

.. math::

   H_{\text{break}}[q] = -\sum_{\ell=1}^{L-1}
   \left[b_\ell \log b_\ell + (1 - b_\ell) \log(1 - b_\ell)\right]
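
Combining the Heads
----------------------

The three heads combine into the posterior module that ``Escapement``
instantiates below. The following is a minimal sketch assuming the head
classes above; the dictionary keys match what the full model consumes, but
the shapes and bookkeeping are illustrative:

.. code-block:: python

   class VariationalTreePosterior(nn.Module):
       def __init__(self, d_model, expected_tmrca):
           super().__init__()
           self.topology = TopologyHead(d_model)
           self.branches = BranchLengthHead(d_model, expected_tmrca)
           self.breaks = BreakpointHead(d_model)

       def forward(self, h, temperature=1.0, hard=False):
           # h: (batch, n_samples, L, d_model) from the encoder
           B, N, L, D = h.shape
           # The topology head runs per position: fold L into the batch.
           h_flat = h.permute(0, 2, 1, 3).reshape(B * L, N, D)
           parent_probs, log_q = self.topology(h_flat, temperature, hard)
           parent_probs = parent_probs.reshape(B, L, N, N)
           log_q_topology = log_q.reshape(B, L, N)
           # Reparameterized log-normal branch times.
           log_mean, sigma = self.branches(h)
           eps = torch.randn_like(log_mean)
           branch_times = torch.exp(log_mean + sigma * eps)
           entropy_branches = (log_mean + 0.5 + torch.log(sigma)
                               + 0.5 * math.log(2 * math.pi)).sum(dim=(1, 2))
           # Bernoulli breakpoints; clamp to keep the entropy finite.
           break_probs = self.breaks(h).clamp(1e-6, 1 - 1e-6)
           entropy_breaks = -(break_probs * torch.log(break_probs)
                              + (1 - break_probs)
                              * torch.log(1 - break_probs)).sum(dim=-1)
           return {"parent_probs": parent_probs,
                   "branch_times": branch_times,
                   "break_probs": break_probs,
                   "log_q_topology": log_q_topology,
                   "entropy_branches": entropy_branches,
                   "entropy_breaks": entropy_breaks}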


Module 3: Differentiable Likelihood
======================================

Module 3 contains no neural networks. It is pure math: given a sampled
genealogy :math:`\tau \sim q`, it computes the three ELBO terms. This module
is the balance spring of the mechanism -- it provides the restoring force that
pulls the variational posterior toward genealogies consistent with the data and
the coalescent model.

The full derivation of Module 3 is in
:ref:`The Differentiable Likelihood <escapement_likelihood>`. Here we summarize
the three components:

**Mutation log-likelihood.** For each sample :math:`i` at position :math:`\ell`,
the probability that its allele differs from that of the proposed parent
:math:`j`, given coalescence time :math:`t_{i,\ell}`, is:

.. math::

   P(d_{i,\ell} \neq d_{j,\ell}) = 1 - \exp(-2 \mu \cdot t_{i,\ell} \cdot s)

where :math:`s` is the span in base pairs and the factor of 2 accounts for
mutations on both lineages. The total mutation log-likelihood sums over all
samples and positions.
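
A minimal sketch of this term under the soft parent assignments from
Module 2 (the function name matches the call in the full model below;
shapes are illustrative):

.. code-block:: python

   def pairwise_mutation_loglik(genotypes, parent_probs, branch_times,
                                mu, span):
       # genotypes: (B, N, L); parent_probs: (B, L, N, N);
       # branch_times: (B, N, L)
       d = genotypes.float()
       # diff[b, l, i, j] = 1 where sample i and candidate parent j differ
       diff = (d.unsqueeze(2) != d.unsqueeze(1)).float().permute(0, 3, 1, 2)
       rate = 2.0 * mu * branch_times * span          # expected mutations
       p_mut = 1.0 - torch.exp(-rate)                 # P(at least one)
       log_p = torch.log(p_mut.clamp_min(1e-12)).permute(0, 2, 1)
       log_q = -rate.permute(0, 2, 1)                 # log P(no mutation)
       per_pair = (diff * log_p.unsqueeze(-1)
                   + (1.0 - diff) * log_q.unsqueeze(-1))
       # Expectation over the soft parent assignment, summed over i and l.
       return (parent_probs * per_pair).sum(dim=(1, 2, 3))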

**Coalescent log-prior.** For constant :math:`N_e`, the TMRCA of a pair of
lineages is exponentially distributed with rate :math:`1/(2N_e)`:

.. math::

   \log P(t \mid N_e) = \log\frac{1}{2N_e} - \frac{t}{2N_e}

For piecewise-constant :math:`N_e(t)`, the hazard is integrated across time
intervals (see :ref:`escapement_likelihood` for the full derivation).
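
A minimal sketch of this piecewise-constant prior, matching the call in the
full model below (handling of times beyond the grid is illustrative):

.. code-block:: python

   def coalescent_log_prior_variable_Ne(branch_times, Ne, time_grid):
       # branch_times: (B, N, L); Ne: (K,); time_grid: (K + 1,) epoch edges
       rates = 1.0 / (2.0 * Ne)                       # hazard per epoch
       widths = time_grid[1:] - time_grid[:-1]
       # Cumulative hazard accrued before each epoch begins.
       cum = F.pad(torch.cumsum(rates * widths, dim=0)[:-1], (1, 0))
       # Epoch index of each time (the last epoch extends past the grid).
       idx = torch.bucketize(branch_times, time_grid[1:-1])
       log_p = (torch.log(rates[idx]) - cum[idx]
                - rates[idx] * (branch_times - time_grid[:-1][idx]))
       return log_p.sum(dim=(1, 2))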

**Entropy.** The sum of topology, branch-length, and breakpoint entropies, all
computed in closed form from the variational parameters.


Module 4: Demographic Inference
==================================

The fourth module parameterizes the effective population size trajectory
:math:`N_e(t)`. This is a learnable parameter, optimized jointly with the
variational posterior by maximizing the ELBO.

Escapement supports three parameterizations:

Piecewise-Constant
---------------------

The simplest option: :math:`N_e(t)` is constant within each of :math:`K`
time intervals. The parameters are :math:`K` values in log-space:

.. code-block:: python

   n_bins = 20
   log_Ne = nn.Parameter(torch.full((n_bins,), math.log(10000.0)))
   time_grid = torch.linspace(0, 200000, n_bins + 1)

   def get_Ne(t):
       # Look up the bin containing each time t; the last bin extends
       # beyond the end of the grid.
       idx = torch.bucketize(t, time_grid[1:-1])
       return torch.exp(log_Ne)[idx]

This is the same parameterization as :ref:`PSMC <psmc_timepiece>`. The
log-space representation ensures positivity and allows the optimizer to work on
a natural scale (multiplicative changes in :math:`N_e` correspond to additive
changes in :math:`\log N_e`).

Neural Spline
----------------

For smoother trajectories, :math:`N_e(t)` can be parameterized as a monotonic
rational-quadratic spline in log-space. The spline knots are at fixed time
points, and the knot values and derivatives are predicted by a small MLP
conditioned on the latent representation:

.. math::

   \log N_e(t) = \text{RationalQuadraticSpline}(t \mid \mathbf{w}, \mathbf{h},
   \mathbf{s})

where :math:`\mathbf{w}` (widths), :math:`\mathbf{h}` (heights), and
:math:`\mathbf{s}` (slopes) are the spline parameters. This allows
:math:`N_e(t)` to vary smoothly while maintaining the flexibility to capture
sharp bottlenecks.
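
A minimal sketch of evaluating such a spline at query times, using the
standard monotonic rational-quadratic interpolant of Durkan et al. (2019).
The knot tensors here are illustrative stand-ins for the MLP outputs; a full
implementation would also map unconstrained parameters to valid widths,
heights, and slopes:

.. code-block:: python

   def rq_spline(t, knot_x, knot_y, deriv):
       # knot_x: (K + 1,) increasing knot times; knot_y: (K + 1,) knot
       # values (log Ne); deriv: (K + 1,) positive knot slopes. Increasing
       # knot_y plus positive derivatives make the interpolant monotone.
       tc = t.clamp(knot_x[0], knot_x[-1])
       k = torch.bucketize(tc, knot_x[1:-1])   # bin index per query time
       x0, x1 = knot_x[k], knot_x[k + 1]
       y0, y1 = knot_y[k], knot_y[k + 1]
       d0, d1 = deriv[k], deriv[k + 1]
       s = (y1 - y0) / (x1 - x0)               # secant slope of the bin
       xi = (tc - x0) / (x1 - x0)              # position within the bin
       num = (y1 - y0) * (s * xi**2 + d0 * xi * (1 - xi))
       den = s + (d1 + d0 - 2 * s) * xi * (1 - xi)
       return y0 + num / den                   # log Ne(t)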

Gaussian Process
-------------------

For full Bayesian treatment, :math:`\log N_e(t)` can be modeled as a Gaussian
process with an RBF kernel. The GP posterior is approximated by variational
inducing points:

.. math::

   \log N_e(t) \sim \mathcal{GP}\!\left(m(t),\; k(t, t')\right), \qquad
   k(t, t') = \sigma^2 \exp\!\left(-\frac{(t - t')^2}{2\ell^2}\right)

The inducing-point parameters are optimized jointly with the ELBO. This
approach is inspired by :ref:`SINGER <singer_timepiece>`, which uses a GP
prior on branch lengths, and :ref:`phlash <phlash_timepiece>`, which uses
SVGD for Bayesian demographic inference.
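
As a deliberately simplified illustration, the GP prior can be evaluated
directly at the midpoints of the piecewise grid from above, rather than
through inducing points; the resulting log-density scores a candidate
:math:`\log N_e` trajectory and can be added to the ELBO as a regularizer.
All names and hyperparameters here are illustrative:

.. code-block:: python

   def gp_log_prior(log_Ne, times, mean=math.log(1e4), var=1.0, ls=2e4):
       # RBF Gram matrix over the time grid, with jitter for stability.
       sq = (times.unsqueeze(0) - times.unsqueeze(1)) ** 2
       K = var * torch.exp(-0.5 * sq / ls**2) + 1e-6 * torch.eye(len(times))
       prior = torch.distributions.MultivariateNormal(
           torch.full_like(times, mean), covariance_matrix=K)
       return prior.log_prob(log_Ne)

   midpoints = 0.5 * (time_grid[:-1] + time_grid[1:])
   regularizer = gp_log_prior(log_Ne, midpoints)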

.. admonition:: Joint optimization of φ and N_e

   The variational parameters :math:`\phi` (neural network weights) and the
   demographic parameters :math:`N_e(t)` are optimized jointly by maximizing
   the same ELBO objective. This is possible because the coalescent prior
   :math:`P(\tau \mid N_e)` depends on :math:`N_e(t)`, so the ELBO is a
   function of both :math:`\phi` and :math:`N_e`. In practice, the two sets
   of parameters use different learning rates: the neural network parameters
   use a standard rate (:math:`3 \times 10^{-4}`), while the :math:`N_e`
   parameters use a higher rate (:math:`10^{-2}`) because they are fewer and
   more directly constrained by the data.

   .. code-block:: python

      param_groups = [
          {"params": encoder_params, "lr": 3e-4},
          {"params": [log_Ne], "lr": 1e-2},
      ]
      optimizer = torch.optim.Adam(param_groups)


Putting It All Together
=========================

The complete Escapement model chains all four modules:

.. code-block:: python

   class Escapement(nn.Module):
       def __init__(self, d_model=64, n_heads=4, n_layers=2, window=64,
                    Ne=10000.0, mu=1.25e-8, rho=1e-8, span=100.0):
           super().__init__()
           self.encoder = GenealogyEncoder(d_model, n_heads, n_layers, window)
           self.var_posterior = VariationalTreePosterior(d_model, 2.0 * Ne)
           self.Ne = Ne
           self.mu = mu
           self.rho = rho
           self.span = span
           self.n_Ne_bins = 20
           self.log_Ne = nn.Parameter(
               torch.full((self.n_Ne_bins,), math.log(max(Ne, 1.0))))
           # Epoch edges for the piecewise-constant N_e(t) prior.
           self.register_buffer(
               "time_grid", torch.linspace(0.0, 200000.0, self.n_Ne_bins + 1))

       def forward(self, genotypes, temperature=1.0, hard=False):
           h = self.encoder(genotypes)
           posterior = self.var_posterior(h, temperature, hard)
           bt = posterior["branch_times"].clamp(min=1.0)

           log_lik = pairwise_mutation_loglik(
               genotypes, posterior["parent_probs"], bt, self.mu, self.span)
           log_prior = coalescent_log_prior_variable_Ne(
               bt, torch.exp(self.log_Ne), self.time_grid)
           log_prior_break = breakpoint_log_prior(
               posterior["break_probs"], self.rho, self.span)

           log_q_topo = posterior["log_q_topology"].sum(dim=(1, 2))
           entropy = posterior["entropy_branches"] + posterior["entropy_breaks"]

           elbo = log_lik + log_prior + log_prior_break + entropy - log_q_topo
           return {"elbo": elbo, **posterior}

       def loss(self, genotypes, temperature=1.0):
           return -self.forward(genotypes, temperature)["elbo"].mean()
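
To close the loop, here is a minimal training step on synthetic genotypes.
This is an illustrative sketch: ``breakpoint_log_prior`` is assumed to come
from :ref:`escapement_likelihood`, and a real run would anneal the
Gumbel-softmax temperature and use the per-group learning rates from
Module 4:

.. code-block:: python

   model = Escapement()
   genotypes = torch.randint(0, 2, (1, 20, 500)).float()  # (batch, n, L)
   optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

   optimizer.zero_grad()
   loss = model.loss(genotypes, temperature=1.0)  # negative mean ELBO
   loss.backward()
   optimizer.step()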

.. list-table:: Computational complexity per module
   :header-rows: 1
   :widths: 22 28 50

   * - Module
     - Complexity
     - Bottleneck
   * - Genealogy Encoder
     - :math:`O(n^2 L d + n L w d)`
     - Sample attention (:math:`n^2` per position) + sliding window
   * - Variational Posterior
     - :math:`O(n^2 L d + n L d)`
     - Topology head (:math:`n^2` attention per position)
   * - Differentiable Likelihood
     - :math:`O(n^2 L + n L K)`
     - Pairwise diff (:math:`n^2`) + variable-:math:`N_e` prior (:math:`K` bins)
   * - Demographic Inference
     - :math:`O(K)`
     - Trivial (just :math:`K` exponentiated parameters)

Total: :math:`O(n^2 L d)`, dominated by the attention mechanisms. For typical
applications (:math:`n \leq 50`, :math:`L \sim 500`, :math:`d = 64`), each
ELBO evaluation takes ~10 ms on a modern GPU.
