.. _escapement_training:

======================
Training on Real Data
======================

   *Winding a mainspring requires an external key. Regulating an escapement
   requires only the mechanism itself -- you adjust the regulator screw, observe
   the rate, and iterate. The data is both the input and the standard.*

This chapter describes how Escapement is trained. The distinction from
:ref:`Mainspring <mainspring_training>` is fundamental: there is no simulation
engine, no curriculum of increasing complexity, no simulated ground truth. There
is only the observed genotype matrix :math:`\mathbf{D}` and the coalescent
likelihood. The training loop optimizes the ELBO on real data.


The Training Loop
===================

Escapement's training loop is conceptually simple:

1. Encode: pass :math:`\mathbf{D}` through the genealogy encoder to produce
   latent vectors :math:`\mathbf{h}`.
2. Sample: draw a genealogy :math:`\tau \sim q(\tau \mid \mathbf{D}, \phi)`
   from the variational posterior.
3. Evaluate: compute the three ELBO terms (mutation likelihood, coalescent
   prior, entropy) using the differentiable likelihood.
4. Maximize: backpropagate through the ELBO and update :math:`\phi` (network
   weights) and :math:`\theta` (:math:`N_e(t)` parameters).

No msprime. No ground truth. No simulated ARGs. The ELBO is the only loss.

.. code-block:: python

   import torch
   from torch.optim import Adam

   def train_escapement(model, genotypes, n_steps=2000, lr_encoder=3e-4,
                        lr_Ne=1e-2, device='cuda'):
       """Train Escapement on a single observed genotype matrix."""
       genotypes = genotypes.to(device)
       model = model.to(device)

       param_groups = model.get_param_groups(lr_encoder, lr_Ne)
       optimizer = Adam(param_groups)

       for step in range(n_steps):
           temperature = anneal_temperature(step, n_steps)
           loss = model.loss(genotypes, temperature=temperature)
           optimizer.zero_grad()
           loss.backward()
           torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
           optimizer.step()

           if step % 100 == 0:
               # Monitoring-only forward pass: no graph needed.
               with torch.no_grad():
                   out = model(genotypes, temperature=temperature)
               print(f"Step {step:4d} | ELBO {out['elbo'].mean():.1f} | "
                     f"loglik {out['log_lik'].mean():.1f} | "
                     f"prior {out['log_prior_coal'].mean():.1f} | "
                     f"H {out['entropy'].mean():.1f} | "
                     f"tau {temperature:.3f}")

       return model

.. admonition:: Per-dataset optimization

   Unlike Mainspring, which trains once and runs on many datasets, Escapement
   optimizes separately for each dataset. This is a feature, not a bug: the
   variational posterior is tailored to the specific data, not averaged across
   a training distribution. The cost is time (~10--30 minutes per dataset on
   a GPU). The benefit is that the posterior reflects the actual data, not a
   simulation prior.


Temperature Annealing for Gumbel-Softmax
===========================================

The Gumbel-softmax temperature :math:`\tau` controls the sharpness of the
topology assignments. At high temperature, parent assignments are nearly
uniform (maximum entropy). At low temperature, they approach one-hot vectors
(deterministic topology).

Escapement uses an exponential annealing schedule:

.. math::

   \tau(t) = \tau_{\max} \cdot \left(\frac{\tau_{\min}}{\tau_{\max}}
   \right)^{t/T}

where :math:`t` is the optimization step, :math:`T` is the total number of
steps, :math:`\tau_{\max} = 1.0`, and :math:`\tau_{\min} = 0.1`.

.. code-block:: python

   def anneal_temperature(step, total_steps, tau_max=1.0, tau_min=0.1):
       return tau_max * (tau_min / tau_max) ** (step / total_steps)

The annealing schedule serves two purposes:

1. **Exploration.** At high temperature, the variational posterior explores a
   wide range of topologies. The ELBO gradients push the posterior toward
   promising regions without committing to a single topology too early.

2. **Exploitation.** At low temperature, the posterior concentrates on the
   best topology found during exploration. The branch lengths and breakpoints
   are refined under this near-deterministic topology.

.. list-table:: Temperature annealing phases
   :header-rows: 1
   :widths: 20 20 60

   * - Phase
     - Temperature
     - Behavior
   * - Early (steps 0--500)
     - :math:`\tau \approx 1.0`
     - Soft parent assignments. Topology is uncertain. Gradients flow easily.
       Branch-length scale is calibrated.
   * - Middle (steps 500--1500)
     - :math:`\tau \approx 0.3\text{--}0.5`
     - Parent assignments sharpen. Topology structure emerges. Breakpoints
       begin to localize.
   * - Late (steps 1500--2000)
     - :math:`\tau \approx 0.1`
     - Near-deterministic topology. Fine-tuning of branch lengths, breakpoints,
       and :math:`N_e(t)`. ELBO converges.


Warm-Starting
===============

Escapement's optimization landscape is multi-modal: many different genealogies
can explain the same genotype matrix reasonably well. Starting from a random
initialization risks getting trapped in a poor local optimum, especially for
the discrete topology.

Two warm-starting strategies dramatically improve convergence:

From Mainspring
-----------------

The recommended approach is to initialize Escapement from
:ref:`Mainspring's <mainspring_complication>` output. Mainspring provides a fast
(~1 second), approximate ARG. Escapement then refines this ARG using the
coalescent likelihood.

.. code-block:: python

   def warm_start_from_mainspring(escapement_model, mainspring_model,
                                  genotypes):
       """Initialize Escapement's encoder from Mainspring's output."""
       with torch.no_grad():
           mainspring_out = mainspring_model(genotypes, hard=True)

       escapement_model.encoder.load_state_dict(
           mainspring_model.encoder.state_dict(), strict=False)

       with torch.no_grad():
           ms_times = mainspring_out['times']
           mean_log_t = torch.log(ms_times.clamp(min=1.0))
           escapement_model.var_posterior.branch_lengths.mlp[-1].bias[0] = (
               mean_log_t.mean().item())

       return escapement_model

This is the hybrid pipeline described in :ref:`mainspring_comparison`. The
analogy to horology is precise: the mainspring provides the initial energy (a
good starting point), and the escapement regulates it into precise, calibrated
motion (a principled posterior).
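
In practice, the hybrid pipeline is just these two stages chained together. The
sketch below assumes the ``warm_start_from_mainspring`` and ``train_escapement``
helpers defined in this chapter; ``load_pretrained_mainspring`` is a placeholder
for however a trained Mainspring checkpoint is loaded in your setup, and the
``Escapement`` constructor arguments mirror the complete example at the end of
this chapter.

.. code-block:: python

   # Hybrid pipeline: fast amortized pass, then per-dataset ELBO refinement.
   mainspring = load_pretrained_mainspring("mainspring.pt")  # placeholder loader

   escapement = Escapement(
       d_model=64, n_heads=4, n_layers=2, window=64,
       Ne=10_000, mu=1.25e-8, rho=1e-8, span=100.0)

   escapement = warm_start_from_mainspring(escapement, mainspring, genotypes)
   escapement = train_escapement(escapement, genotypes, n_steps=500)

With the warm start, a few hundred refinement steps are usually enough (see the
convergence notes later in this chapter).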

From tsinfer
--------------

An alternative warm-start uses :ref:`tsinfer <tsinfer_timepiece>` to provide
an initial topology estimate. Since tsinfer scales to much larger sample sizes
than Escapement, this is useful when Mainspring is not available:

.. code-block:: python

   def warm_start_from_tsinfer(escapement_model, ts_inferred, genotypes):
       """Calibrate the branch-length scale for a tsinfer-based warm start."""
       import math
       import numpy as np

       # Only the branch-length scale is calibrated here, from mean pairwise
       # divergence; `ts_inferred` carries the topology estimate described above.
       G = genotypes[0].numpy()
       n = G.shape[0]
       pairwise_div = np.zeros((n, n))
       for i in range(n):
           for j in range(i + 1, n):
               pairwise_div[i, j] = np.mean(G[i] != G[j])
               pairwise_div[j, i] = pairwise_div[i, j]

       mean_div = pairwise_div[np.triu_indices_from(pairwise_div, k=1)].mean()
       estimated_tmrca = mean_div / (2 * escapement_model.mu * escapement_model.span)

       with torch.no_grad():
           escapement_model.var_posterior.branch_lengths.mlp[-1].bias[0] = (
               math.log(max(estimated_tmrca, 1.0)))

       return escapement_model


Joint vs. Alternating Optimization
=====================================

Escapement optimizes two sets of parameters:

- :math:`\phi`: neural network weights (encoder + variational posterior heads)
- :math:`\theta`: demographic parameters (:math:`N_e(t)`)

These can be optimized **jointly** (single optimizer, same gradient step) or in
an **alternating** fashion (update :math:`\phi` for :math:`K_1` steps, then
:math:`\theta` for :math:`K_2` steps, and repeat).

Joint Optimization
--------------------

The simplest approach: both parameter sets share the same ELBO objective and
are updated simultaneously with different learning rates.

.. code-block:: python

   param_groups = [
       {"params": encoder_params, "lr": 3e-4},
       {"params": [log_Ne], "lr": 1e-2},
   ]
   optimizer = Adam(param_groups)

This works well in practice because the :math:`N_e(t)` parameters are few
(typically 20) and directly constrained by the coalescent prior. The different
learning rates accommodate the different scales: the neural network needs small
steps to avoid destabilizing the latent representations, while :math:`N_e(t)`
can take larger steps because the coalescent prior provides a strong signal.
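
For reference, the ``get_param_groups`` call used throughout this chapter is not
spelled out anywhere; a plausible implementation, assuming the demographic
parameters live in a single ``log_Ne`` tensor (as in the alternating loop
below), is:

.. code-block:: python

   def get_param_groups(self, lr_encoder, lr_Ne):
       """Split parameters into an encoder group and an N_e(t) group."""
       encoder_params = [p for name, p in self.named_parameters()
                         if name != "log_Ne"]
       return [
           {"params": encoder_params, "lr": lr_encoder},
           {"params": [self.log_Ne], "lr": lr_Ne},
       ]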

Alternating Optimization
--------------------------

For difficult cases (complex demography, many samples), alternating optimization
can be more stable:

.. code-block:: python

   def alternating_train(model, genotypes, n_outer=100,
                         n_phi=10, n_theta=5):
       opt_phi = Adam([p for n, p in model.named_parameters()
                       if n != "log_Ne"], lr=3e-4)
       opt_theta = Adam([model.log_Ne], lr=1e-2)

       for outer in range(n_outer):
           temp = anneal_temperature(outer, n_outer)

           for _ in range(n_phi):
               loss = model.loss(genotypes, temperature=temp)
               opt_phi.zero_grad()
               loss.backward()
               torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
               opt_phi.step()

           for _ in range(n_theta):
               loss = model.loss(genotypes, temperature=temp)
               opt_theta.zero_grad()
               loss.backward()
               opt_theta.step()

The rationale: when :math:`\phi` is updated with fixed :math:`N_e`, the network
learns to propose genealogies consistent with the current demography. When
:math:`N_e` is updated with fixed :math:`\phi`, the demography adjusts to
match the genealogies currently being proposed. This alternation can help
escape saddle points where joint optimization stalls.


Variance Reduction for Discrete Gradients
============================================

The Gumbel-softmax provides biased but low-variance gradient estimates for the
discrete topology. For the highest-quality gradients, Escapement can optionally
use advanced variance reduction techniques.

NVIL (Neural Variational Inference and Learning)
----------------------------------------------------

NVIL (Mnih & Gregor 2014) uses a learned baseline to reduce the variance of the
REINFORCE gradient for discrete variables:

.. math::

   \nabla_\phi \mathbb{E}_q[f(\tau)] \approx (f(\tau) - c_\psi(\mathbf{h}))
   \cdot \nabla_\phi \log q(\tau \mid \phi)

where :math:`c_\psi(\mathbf{h})` is a neural baseline that predicts the
expected ELBO from the latent representation, trained to minimize
:math:`(f(\tau) - c_\psi(\mathbf{h}))^2`.

.. code-block:: python

   import torch.nn as nn

   class NVILBaseline(nn.Module):
       """Input-dependent baseline c_psi(h): predicts the expected ELBO."""

       def __init__(self, d_model):
           super().__init__()
           self.net = nn.Sequential(
               nn.Linear(d_model, d_model), nn.ReLU(),
               nn.Linear(d_model, 1))

       def forward(self, h):
           # Pool the encoder latents over their sample/site dimensions so the
           # baseline sees one summary vector per batch element.
           return self.net(h.mean(dim=(1, 2))).squeeze(-1)
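
Plugging the baseline into a training step adds a score-function surrogate term
plus a regression loss for the baseline itself. The sketch below is schematic:
the ``hard=True`` flag and the ``'log_q_topology'`` and ``'h'`` outputs are
assumptions about the model's interface, not fixed parts of Escapement's API.

.. code-block:: python

   def nvil_step(model, baseline, genotypes, optimizer):
       """One NVIL-style update; `optimizer` must cover model and baseline."""
       out = model(genotypes, hard=True)      # hard topology sample (assumed kwarg)
       elbo = out['elbo']                     # per-example ELBO, shape (batch,)
       log_q = out['log_q_topology']          # log q(tau | D), shape (batch,)
       c = baseline(out['h'])                 # learned baseline, shape (batch,)

       # Score-function term uses the centred learning signal; the baseline is
       # trained by regression toward the observed ELBO.
       advantage = (elbo - c).detach()
       surrogate = -(elbo + advantage * log_q).mean()
       baseline_loss = ((elbo.detach() - c) ** 2).mean()

       optimizer.zero_grad()
       (surrogate + baseline_loss).backward()
       optimizer.step()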

RELAX
-------

RELAX (Grathwohl et al. 2018) combines the Gumbel-softmax with a control
variate that uses both the discrete and relaxed samples:

.. math::

   \nabla_\phi = (f(\text{hard}) - c_\psi(\tilde{\text{soft}})) \cdot
   \nabla_\phi \log q + \nabla_\phi c_\psi(\text{soft})
   - \nabla_\phi c_\psi(\tilde{\text{soft}})

where :math:`\text{hard}` is the discrete sample, :math:`\text{soft}` is the
Gumbel-softmax relaxation, and :math:`\tilde{\text{soft}}` is a conditional
relaxation. RELAX provides unbiased, low-variance gradient estimates but is
more complex to implement.

In practice, the Gumbel-softmax with straight-through estimation is sufficient
for most applications. NVIL and RELAX are recommended only when the
optimization fails to converge (typically for large :math:`n` or complex
demographies).
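
For completeness, the straight-through Gumbel-softmax that the default path
relies on is available directly in PyTorch. A minimal sketch of drawing a hard
parent assignment while keeping the relaxed gradients (the logits shape is
purely illustrative):

.. code-block:: python

   import torch
   import torch.nn.functional as F

   # Unnormalized parent scores for 10 nodes and 10 candidate parents.
   logits = torch.randn(1, 10, 10)

   # hard=True returns one-hot assignments in the forward pass, while the
   # backward pass uses the relaxed Gumbel-softmax gradients.
   parents = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)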


Practical Considerations
==========================

Window Size
-------------

Escapement processes the genotype matrix in windows along the genome. The window
size :math:`w` (in number of segregating sites) controls the trade-off between
local accuracy and computational cost:

.. list-table::
   :header-rows: 1
   :widths: 25 40 35

   * - Window size
     - Behavior
     - When to use
   * - Small (:math:`w \leq 32`)
     - Captures only very local LD. May miss long-range haplotype sharing.
     - Very high recombination rate; testing and debugging.
   * - Medium (:math:`w = 64\text{--}128`)
     - Good balance. Captures LD within ~1 expected tree span.
     - Default for most applications.
   * - Large (:math:`w \geq 256`)
     - Captures long-range LD. High memory cost (:math:`O(w^2)` per sample).
     - Low recombination rate; large tree spans.

A reasonable default is :math:`w \approx 1/\hat{\rho}`, where :math:`\hat{\rho}`
is the recombination rate in units of recombinations per segregating site.
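
A small helper turns this rule of thumb into a concrete window size; the
clamping bounds below simply mirror the table above and are not part of any
fixed API:

.. code-block:: python

   def default_window_size(rho_per_site, w_min=32, w_max=256):
       """Rule of thumb: w ~ 1 / rho_hat, clamped to a practical range."""
       w = int(round(1.0 / max(rho_per_site, 1e-12)))
       return max(w_min, min(w, w_max))

   # e.g. one recombination per ~100 segregating sites -> w = 100
   print(default_window_size(0.01))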

Batch Construction
--------------------

For long genomes, Escapement processes non-overlapping or overlapping windows
and averages the ELBO contributions. With overlapping windows (stride < window
size), the predictions in the overlap region are averaged, providing smoother
breakpoint estimates:

.. code-block:: python

   def process_genome_in_windows(model, genotypes, window=256,
                                 stride=192, temperature=1.0):
       """Average the per-window ELBO over (possibly overlapping) windows."""
       B, N, L = genotypes.shape
       total_elbo = 0.0
       n_windows = 0

       for start in range(0, L - window + 1, stride):
           chunk = genotypes[:, :, start:start + window]
           out = model(chunk, temperature=temperature)
           total_elbo += out['elbo'].sum()
           n_windows += 1

       return total_elbo / n_windows

Convergence Monitoring
------------------------

Since there is no ground truth, convergence must be monitored through the ELBO
itself and its components:

.. list-table:: Convergence diagnostics
   :header-rows: 1
   :widths: 30 70

   * - Diagnostic
     - What it tells you
   * - **ELBO trajectory**
     - Should increase monotonically (on average). Stalling indicates a local
       optimum or learning rate issue.
   * - **Mutation log-likelihood**
     - Should increase as the proposed genealogy better explains the data. If
       it plateaus early, the topology may be stuck.
   * - **Coalescent log-prior**
     - Should increase as branch lengths become consistent with :math:`N_e(t)`.
       If it decreases while the likelihood increases, the branch lengths are
       being pulled away from the coalescent prior (tension between data and
       model).
   * - **Entropy**
     - Should decrease during annealing (topology sharpens). If it stays high,
       the model is uncertain about the topology -- possibly because the data
       is not informative enough.
   * - **:math:`N_e(t)` trajectory**
     - Plot :math:`N_e(t)` every 100 steps. It should stabilize. Oscillations
       indicate that the learning rate is too high.

.. code-block:: python

   def monitor_convergence(model, genotypes, step, history):
       with torch.no_grad():
           out = model(genotypes, temperature=0.5)
       history['elbo'].append(out['elbo'].mean().item())
       history['loglik'].append(out['log_lik'].mean().item())
       history['prior'].append(out['log_prior_coal'].mean().item())
       history['entropy'].append(out['entropy'].mean().item())
       history['Ne'].append(model.get_Ne().cpu().numpy().copy())
       return history

.. admonition:: When to stop

   Stop training when:

   1. The ELBO has not improved by more than 0.1% over the last 200 steps.
   2. The :math:`N_e(t)` trajectory has stabilized (relative change < 1% per
      100 steps).
   3. The Gumbel-softmax temperature has reached :math:`\tau_{\min}`.

   Typical convergence: 1,000--3,000 steps for simple demography (constant
   :math:`N_e`), 3,000--10,000 steps for complex demography (multiple
   bottlenecks). With warm-starting from Mainspring, these numbers drop by
   a factor of 3--5.
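
These criteria translate directly into a stopping check over the ``history``
dictionary filled by ``monitor_convergence``. A sketch, assuming logging every
100 steps (so two history entries span roughly 200 steps) and the thresholds
from the admonition above:

.. code-block:: python

   import numpy as np

   def should_stop(history, temperature, tau_min=0.1,
                   elbo_tol=1e-3, ne_tol=1e-2):
       """Apply the three stopping criteria to the logged history."""
       # Criterion 3: the temperature has (approximately) reached tau_min.
       if temperature > 1.05 * tau_min or len(history['elbo']) < 3:
           return False

       # Criterion 1: ELBO improved by < 0.1% over the last ~200 steps.
       recent, previous = history['elbo'][-1], history['elbo'][-3]
       elbo_flat = abs(recent - previous) <= elbo_tol * abs(previous)

       # Criterion 2: relative change in N_e(t) < 1% per 100 steps.
       ne_now, ne_prev = history['Ne'][-1], history['Ne'][-2]
       ne_stable = np.max(np.abs(ne_now - ne_prev) / ne_prev) < ne_tol

       return elbo_flat and ne_stable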


A Complete Training Example
==============================

Putting it all together: training Escapement on a genotype matrix from a
population with a bottleneck.

.. code-block:: python

   import torch
   import math
   from model import Escapement

   torch.manual_seed(42)

   # Stand-in data for illustration: in practice, load a real genotype matrix
   # of shape (1, n_samples, n_sites) with 0/1 entries.
   n_samples, n_sites = 20, 200
   genotypes = torch.bernoulli(torch.full((1, n_samples, n_sites), 0.15))

   model = Escapement(
       d_model=64, n_heads=4, n_layers=2, window=64,
       Ne=10_000, mu=1.25e-8, rho=1e-8, span=100.0)

   time_grid = torch.tensor([0, 2000, 5000, 10000, 20000, 50000, 1e8])
   model.log_Ne = torch.nn.Parameter(torch.full((6,), math.log(10000.0)))
   model.n_Ne_bins = 6
   model.set_time_grid(time_grid)

   optimizer = torch.optim.Adam(model.get_param_groups(lr_encoder=3e-4,
                                                        lr_Ne=1e-2))

   for step in range(2000):
       # Linear anneal from 1.0 to 0.1; the exponential schedule defined
       # earlier can be used instead.
       temp = max(0.1, 1.0 - step / 2000)
       loss = model.loss(genotypes, temperature=temp)
       optimizer.zero_grad()
       loss.backward()
       torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
       optimizer.step()

   results = model.infer(genotypes)
   ne_trajectory = model.get_Ne().detach()
   print("Inferred N_e(t):", ne_trajectory.numpy())
   print("Final ELBO:", results['elbo'].mean().item())

No simulations were harmed in the making of this inference.
