The Differentiable Likelihood
The balance spring is pure physics. No mechanism, no escapement, no gear train – just the restoring force of an elastic strip, precisely calibrated to oscillate at a known frequency. It is the simplest component in the watch, and the most critical.
Module 3 of Escapement contains no neural networks. It is the mathematical core of the system: given a sampled genealogy \(\tau \sim q(\tau \mid \mathbf{D}, \phi)\), it computes three scalar quantities whose sum is the ELBO. These three quantities – the mutation log-likelihood, the coalescent log-prior, and the variational entropy – are the same equations derived across every Timepiece in this book, now assembled into a single differentiable objective.
This chapter derives each term from first principles, connects each to its source Timepiece, shows how gradients flow through the reparameterization trick, and provides the complete PyTorch implementation.
Mutation Log-Likelihood
The mutation log-likelihood answers: given the proposed genealogy \(\tau\), how well does it explain the observed data \(\mathbf{D}\)?
Derivation
Under the infinite-sites mutation model, mutations accumulate on each edge of the genealogy as a Poisson process with rate \(\mu\) per base pair per generation. For an edge connecting sample \(i\) to its coalescent ancestor \(j\) with TMRCA \(t_{i,\ell}\) at position \(\ell\), spanning \(s\) base pairs, the expected number of mutations on both lineages from leaves to MRCA is:

\[
\mathbb{E}[\text{mutations}] = 2 \mu t_{i,\ell} s.
\]
The factor of 2 accounts for mutations on both the \(i \to \text{MRCA}\) and \(j \to \text{MRCA}\) branches. The probability that \(i\) and \(j\) differ at this position (at least one mutation occurred) is:

\[
p_{i,\ell} = 1 - e^{-2 \mu t_{i,\ell} s}.
\]
This is the Jukes-Cantor two-allele model, a standard approximation used in tsdate and ARGweaver.
The log-likelihood of a single observation \(d_{i,\ell}\) given the proposed parent \(j\) and coalescence time \(t_{i,\ell}\) is a Bernoulli log-likelihood on the mismatch indicator:

\[
\log P(d_{i,\ell} \mid j, t_{i,\ell}) = m_{i,\ell} \log p_{i,\ell} + (1 - m_{i,\ell}) \log\bigl(1 - p_{i,\ell}\bigr),
\]
where \(m_{i,\ell} = |d_{i,\ell} - d_{j,\ell}|\) is the observed mismatch (0 or 1) and \(p_{i,\ell} = 1 - e^{-2\mu t_{i,\ell} s}\) is the predicted mismatch probability.
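To make the scale concrete, a quick numeric sketch of the mismatch probability; the mutation rate, TMRCA, and span below are assumed illustrative values:

```python
import math

# Assumed illustrative values: a human-like mutation rate and a pairwise
# TMRCA of 20,000 generations at a single base pair.
mu = 1.25e-8   # mutations per bp per generation
t = 20_000.0   # coalescence time in generations
s = 1.0        # span in base pairs

expected_mutations = 2.0 * mu * t * s            # Poisson mean over both lineages
p_mismatch = 1.0 - math.exp(-expected_mutations)
# At rates this small, p_mismatch is essentially equal to the expected count.
```

For deep human pairwise TMRCAs the per-site mismatch probability is thus on the order of \(10^{-4}\), which is why long stretches of identical sequence are informative about recent coalescence.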
Soft Parent Assignments
Because the topology is represented as soft parent probabilities \(\alpha_{ij}^\ell\) (from the Gumbel-softmax), the “parent genotype” is a probability-weighted average:

\[
\hat{d}_{j(i),\ell} = \sum_{j} \alpha_{ij}^{\ell} \, d_{j,\ell},
\]
and the “mismatch” is the soft absolute difference \(\tilde{m}_{i,\ell} = |d_{i,\ell} - \hat{d}_{j(i),\ell}|\). This allows gradients to flow through the topology via the Gumbel-softmax.
```python
import torch

def pairwise_mutation_loglik(genotypes, parent_probs, branch_times,
                             mu, span=1.0):
    # genotypes: (B, N, L); parent_probs: (B, L, N, N); branch_times: (B, N, L)
    B, N, L = genotypes.shape
    # Soft parent genotype: probability-weighted average over candidate parents.
    parent_geno = torch.einsum("blij,bjl->bil", parent_probs, genotypes)
    obs_diff = (genotypes - parent_geno).abs()
    rate = (2.0 * mu * branch_times * span).clamp(max=20.0)
    p_mismatch = (1.0 - torch.exp(-rate)).clamp(1e-7, 1.0 - 1e-7)
    # Bernoulli log-likelihood of the (soft) mismatch indicator.
    loglik = (obs_diff * torch.log(p_mismatch)
              + (1.0 - obs_diff) * torch.log(1.0 - p_mismatch))
    return loglik.sum(dim=(1, 2))
```
The clamp(max=20.0) on the rate caps the Poisson exponent for very long branches: beyond a rate of about 20 the mismatch probability is already 1 to within floating-point precision, so letting the exponent grow further only drives exp(-rate) into underflow. The clamp(1e-7, 1.0 - 1e-7) on the mismatch probability prevents log(0) at either extreme.
Connection to tsdate’s mutation likelihood
tsdate derives the same Poisson mutation model but applies it to a fixed tree topology with known parent assignments. For an edge \(e\) carrying \(y_e\) observed mutations, with branch length \(t_e\) and span \(s_e\), the edge log-likelihood in tsdate is the Poisson log-probability:

\[
\log P(y_e \mid t_e) = y_e \log(\mu t_e s_e) - \mu t_e s_e - \log(y_e!).
\]
Escapement uses a simpler Bernoulli approximation (mismatch vs. no mismatch) rather than counting mutations per edge, because the parent assignment is soft (probabilistic) rather than hard (deterministic). When the Gumbel-softmax temperature is low and the parent assignment approaches one-hot, the two formulations converge.
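The convergence claim can be checked numerically in the single-mutation regime: for a hard parent assignment and a small expected count \(2\mu t s\), the Bernoulli mismatch log-likelihood and the Poisson one-mutation log-likelihood agree to first order. A minimal sketch with assumed illustrative values:

```python
import math

mu, t, s = 1.25e-8, 1_000.0, 1.0   # assumed illustrative values
lam = 2.0 * mu * t * s             # Poisson mean, well below 1

# Bernoulli formulation: log-probability of observing a mismatch.
bern = math.log(1.0 - math.exp(-lam))
# Poisson formulation: log-probability of exactly one mutation on the edge.
pois = math.log(lam) - lam

gap = abs(bern - pois)             # differs only at order lam (about lam / 2)
```

As \(\lambda = 2\mu t s \to 0\), both expressions tend to \(\log \lambda\), so the two formulations coincide in the regime where at most one mutation per edge is plausible.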
Auxiliary Pairwise-Difference Likelihood
In addition to the primary parent-based likelihood, Escapement includes an auxiliary pairwise-difference log-likelihood that provides a global signal for branch-length calibration:

\[
\log P_{\text{pair}} = \sum_{i < j} \sum_{\ell} \Bigl[ m_{ij,\ell} \log p_{ij,\ell} + (1 - m_{ij,\ell}) \log\bigl(1 - p_{ij,\ell}\bigr) \Bigr],
\qquad
p_{ij,\ell} = 1 - e^{-2 \mu \hat{t}_{ij} s},
\]
where \(m_{ij,\ell} = |d_{i,\ell} - d_{j,\ell}|\) and the pairwise TMRCA is estimated as \(\hat{t}_{ij} = (t_i + t_j)/2\). This term is weighted by a small coefficient (typically 0.1) and provides a signal that does not depend on the topology – only on the overall scale of branch lengths.
```python
def pairwise_diff_loglik(genotypes, branch_times, mu, span=1.0):
    # genotypes: (B, N, L); branch_times: (B, N, L)
    B, N, L = genotypes.shape
    # All-pairs mismatch indicators: (B, N, N, L).
    diff = (genotypes.unsqueeze(2) - genotypes.unsqueeze(1)).abs()
    # Pairwise TMRCA estimate: average of the two coalescence times.
    t_pair = (branch_times.unsqueeze(2) + branch_times.unsqueeze(1)) / 2.0
    rate_pair = (2.0 * mu * t_pair * span).clamp(max=20.0)
    p_mismatch = (1.0 - torch.exp(-rate_pair)).clamp(1e-7, 1.0 - 1e-7)
    loglik = (diff * torch.log(p_mismatch)
              + (1.0 - diff) * torch.log(1.0 - p_mismatch))
    # Count each unordered pair once via an upper-triangular mask.
    mask = torch.triu(torch.ones(N, N, device=genotypes.device), diagonal=1)
    return (loglik * mask.unsqueeze(0).unsqueeze(-1)).sum(dim=(1, 2, 3))
```
Coalescent Log-Prior
The coalescent log-prior answers: is the proposed genealogy plausible under the coalescent model with the proposed \(N_e(t)\)?
Constant \(N_e\)
For a pair of lineages under the Kingman coalescent with constant effective population size \(N_e\), the TMRCA is exponentially distributed:

\[
T \sim \text{Exp}\!\left(\frac{1}{2 N_e}\right),
\qquad
P(t \mid N_e) = \frac{1}{2 N_e} \, e^{-t / (2 N_e)}.
\]
The log-prior for a single coalescence time \(t\) is:

\[
\log P(t \mid N_e) = -\log(2 N_e) - \frac{t}{2 N_e}.
\]
For a complete genealogy with \(n\) samples, there are \(n - 1\) coalescence events. Under Escapement’s factored approximation, the prior is the product over all sample-position pairs:

\[
\log P(\tau \mid N_e) = \sum_{i} \sum_{\ell} \log P(t_{i,\ell} \mid N_e).
\]
```python
import math

def coalescent_log_prior(branch_times, Ne=10000.0):
    # Exponential log-density log(rate) - rate * t per (sample, position).
    rate = 1.0 / (2.0 * Ne)
    log_prior = math.log(rate) - rate * branch_times
    return log_prior.sum(dim=(1, 2))
```
Connection to msprime and PSMC
The exponential distribution of pairwise coalescence times is derived in msprime as the foundation of the Kingman coalescent. PSMC uses the same distribution but discretizes it into time intervals for the HMM. Escapement uses the continuous distribution directly, avoiding discretization artifacts.
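A quick Monte Carlo sketch of this distribution (pure Python, with an assumed constant \(N_e\)): pairwise coalescence times drawn at rate \(1/(2N_e)\) should average \(2N_e\) generations.

```python
import random
import statistics

random.seed(0)
Ne = 10_000.0                # assumed constant effective population size
rate = 1.0 / (2.0 * Ne)      # pairwise coalescent rate

# Draw pairwise TMRCAs from the exponential coalescent distribution.
times = [random.expovariate(rate) for _ in range(200_000)]
mean_tmrca = statistics.fmean(times)   # should be close to 2 * Ne = 20,000
```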
Piecewise-Constant \(N_e(t)\)
For realistic demography, \(N_e(t)\) varies through time. With piecewise-constant \(N_e(t)\) on a grid \(0 = t_0 < t_1 < \cdots < t_K\), the coalescent rate in interval \(k\) is \(\lambda_k = 1/(2 N_e^{(k)})\). The log-prior of a coalescence time \(t\) involves integrating the hazard:

\[
\log P(t) = \log \lambda(t) - \int_0^t \lambda(s)\, ds,
\]
where \(\lambda(t)\) is the instantaneous rate at the coalescence time and the integral is the cumulative hazard:

\[
\int_0^t \lambda(s)\, ds = \sum_{k < k(t)} \lambda_k \, \Delta t_k + \lambda_{k(t)} \bigl(t - t_{k(t)}\bigr),
\]

where \(k(t)\) indexes the interval in which \(t\) falls.
Here \(\Delta t_k = t_{k+1} - t_k\) is the width of interval \(k\). The coalescence time \(t\) contributes \(\lambda_k \Delta t_k\) for each interval it fully spans, and a partial contribution \(\lambda_k (t - t_k)\) for the interval in which it falls.
```python
def coalescent_log_prior_variable_Ne(branch_times, Ne_fn, time_grid):
    # Ne_fn: (K,) piecewise-constant Ne values; time_grid: (K,) interval starts.
    K = Ne_fn.shape[0]
    Ne_clamped = Ne_fn.clamp(min=1.0)
    cumulative_hazard = torch.zeros_like(branch_times)
    t_remaining = branch_times.clone()
    instantaneous_rate = torch.zeros_like(branch_times)
    for k in range(K):
        rate_k = 1.0 / (2.0 * Ne_clamped[k])
        if k < K - 1:
            dt_interval = time_grid[k + 1] - time_grid[k]
        else:
            # The last interval is open-ended: make it wide enough for all times.
            dt_interval = t_remaining.max().detach().item() + 1.0
        # Time each coalescence spends inside interval k (possibly partial).
        dt_used = torch.clamp(t_remaining, max=dt_interval)
        cumulative_hazard = cumulative_hazard + rate_k * dt_used
        # Record the rate of the interval in which each coalescence falls.
        in_this_bin = (t_remaining > 0) & (t_remaining <= dt_interval)
        instantaneous_rate = torch.where(
            in_this_bin, rate_k, instantaneous_rate)
        t_remaining = (t_remaining - dt_used).clamp(min=0.0)
    instantaneous_rate = instantaneous_rate.clamp(min=1e-30)
    log_prior = torch.log(instantaneous_rate) - cumulative_hazard
    return log_prior.sum(dim=(1, 2))
```
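As a sanity check on the cumulative-hazard logic, here is a scalar pure-Python version (a sketch, not the library code) that must reduce to the constant-\(N_e\) closed form \(\log \lambda - \lambda t\) when every interval shares the same \(N_e\):

```python
import math

def log_prior_piecewise(t, Ne_vals, grid):
    """Scalar piecewise-constant coalescent log-prior:
    log lambda(t) minus the cumulative hazard up to t."""
    hazard = 0.0
    for k, Ne_k in enumerate(Ne_vals):
        rate_k = 1.0 / (2.0 * Ne_k)
        upper = grid[k + 1] if k + 1 < len(grid) else float("inf")
        if t > upper:
            hazard += rate_k * (upper - grid[k])   # interval fully spanned
        else:
            hazard += rate_k * (t - grid[k])       # partial final interval
            return math.log(rate_k) - hazard
    raise ValueError("t lies beyond the last grid interval")

Ne, t = 10_000.0, 7_500.0           # assumed illustrative values
grid = [0.0, 1_000.0, 5_000.0]      # interval start times (last bin open-ended)
piecewise = log_prior_piecewise(t, [Ne, Ne, Ne], grid)
closed_form = math.log(1.0 / (2.0 * Ne)) - t / (2.0 * Ne)
```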
Gradient flow through \(N_e\)
The coalescent log-prior is differentiable with respect to
\(N_e^{(k)}\) (through the Ne_fn tensor). This is what enables joint
optimization of the demography: when the ELBO is maximized, the gradients
through the coalescent prior push \(N_e(t)\) toward values consistent
with the proposed coalescence times, while the gradients through the mutation
likelihood push the coalescence times toward values consistent with the
observed data. The two signals converge to a self-consistent estimate of
both the genealogy and the demography.
Breakpoint Log-Prior
Under the Sequential Markov Coalescent, recombination breakpoints between adjacent sites occur as a Poisson process with rate \(\rho\) per base pair per generation. For a segment of span \(s\) base pairs, the probability of at least one recombination event (and hence a tree change) is:

\[
p_{\text{break}} = 1 - e^{-\rho s}.
\]
The breakpoint log-prior for the predicted breakpoint probabilities \(b_\ell\) is:

\[
\log P(\mathbf{b} \mid \rho) = \sum_{\ell} \Bigl[ b_\ell \log p_{\text{break}} + (1 - b_\ell) \log\bigl(1 - p_{\text{break}}\bigr) \Bigr].
\]
```python
def breakpoint_log_prior(break_probs, rho, span=1.0):
    # Prior probability of a tree change between adjacent sites.
    p_break = max(1.0 - math.exp(-rho * span), 1e-8)
    p_no_break = 1.0 - p_break
    bp = break_probs.clamp(1e-8, 1.0 - 1e-8)
    log_prior = bp * math.log(p_break) + (1.0 - bp) * math.log(p_no_break)
    return log_prior.sum(dim=1)
```
This term acts as a regularizer: it penalizes predicted breakpoint probabilities that are inconsistent with the recombination rate. For typical human recombination rates (\(\rho \approx 10^{-8}\) per bp per generation) and typical inter-site spacing (\(s \approx 100\) bp), the prior breakpoint probability is very low (\(\approx 10^{-6}\)), encouraging the model to predict few breakpoints.
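Plugging in the quoted numbers confirms the scale:

```python
import math

rho = 1e-8    # recombination rate per bp per generation
span = 100.0  # inter-site spacing in base pairs

p_break = 1.0 - math.exp(-rho * span)   # prior tree-change probability
# rho * span = 1e-6, so p_break is approximately 1e-6.
```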
Entropy Decomposition
The entropy \(H[q]\) of the variational posterior decomposes into three independent terms due to the mean-field factorization:

\[
H[q] = H_{\text{topo}} + H_{\text{branch}} + H_{\text{break}}.
\]
Topology Entropy (Categorical)
The topology at each position is a categorical distribution over parent assignments:

\[
q\bigl(\text{pa}(i) = j \mid \ell\bigr) = \alpha_{ij}^{\ell}.
\]

In the ELBO, the topology term appears as \(-\mathbb{E}_q[\log q(\text{topology})]\), estimated from the log-probability of the sampled parent (computed by the topology head as chosen_log_probs):

\[
-\mathbb{E}_q[\log q(\text{topology})] \approx -\sum_{\ell} \sum_{i} \log \alpha^{\ell}_{i \hat{j}(i)},
\]

where \(\hat{j}(i)\) is the parent sampled for \(i\) at position \(\ell\).
Branch-Length Entropy (Log-Normal)
The branch lengths are log-normally distributed with parameters \(\mu_{i,\ell}\) and \(\sigma_{i,\ell}\). The log-normal entropy has a closed-form expression:

\[
H\bigl[\text{LogNormal}(\mu, \sigma)\bigr] = \mu + \tfrac{1}{2} + \log \sigma + \tfrac{1}{2} \log(2\pi).
\]
The total branch-length entropy sums over all sample-position pairs:

\[
H_{\text{branch}} = \sum_{i} \sum_{\ell} \Bigl( \mu_{i,\ell} + \tfrac{1}{2} + \log \sigma_{i,\ell} + \tfrac{1}{2} \log(2\pi) \Bigr).
\]
```python
def lognormal_entropy(log_mean, log_std):
    # log_mean is mu (the mean of log t); log_std is log sigma.
    return log_mean + 0.5 + log_std + 0.5 * math.log(2 * math.pi)
```
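The closed form is easy to verify by Monte Carlo, since the entropy is \(-\mathbb{E}_q[\log q(t)]\) under the log-normal density (a pure-Python sketch with assumed variational parameters):

```python
import math
import random

random.seed(0)
mu_q, sigma_q = 8.0, 0.5   # assumed variational parameters (mu, sigma)

# Closed form: H = mu + 1/2 + log(sigma) + 0.5 * log(2 * pi).
closed_form = mu_q + 0.5 + math.log(sigma_q) + 0.5 * math.log(2.0 * math.pi)

def log_pdf(t):
    # Log-normal density evaluated at t.
    z = (math.log(t) - mu_q) / sigma_q
    return -math.log(t * sigma_q * math.sqrt(2.0 * math.pi)) - 0.5 * z * z

# Monte Carlo entropy: -E_q[log q(t)] with t = exp(mu + sigma * eps).
draws = [math.exp(mu_q + sigma_q * random.gauss(0.0, 1.0))
         for _ in range(100_000)]
mc_entropy = -sum(log_pdf(t) for t in draws) / len(draws)
```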
Why log-normal and not gamma?
tsdate and Gamma-SMC use gamma distributions for coalescence-time posteriors, which is the conjugate choice for the exponential coalescent prior. Escapement uses log-normal instead, for a practical reason: the log-normal reparameterization trick (\(t = \exp(\mu + \sigma\epsilon)\), \(\epsilon \sim \mathcal{N}(0,1)\)) produces lower-variance gradient estimates than the gamma reparameterization.
The gamma distribution can be reparameterized (Figurnov et al. 2018), but the resulting gradient estimates have higher variance, especially for small shape parameters. Since Escapement’s optimization is already challenging (discrete topology, multi-modal landscape), we prioritize gradient quality over distributional faithfulness.
Breakpoint Entropy (Bernoulli)
The breakpoint at each position is Bernoulli-distributed with probability \(b_\ell\); its entropy is:

\[
H\bigl[\text{Bernoulli}(b_\ell)\bigr] = -b_\ell \log b_\ell - (1 - b_\ell) \log(1 - b_\ell).
\]
The total breakpoint entropy is:

\[
H_{\text{break}} = \sum_{\ell} \Bigl[ -b_\ell \log b_\ell - (1 - b_\ell) \log(1 - b_\ell) \Bigr].
\]
```python
def bernoulli_entropy(p, eps=1e-8):
    # Clamp away from {0, 1} to keep the logs finite.
    p = p.clamp(eps, 1.0 - eps)
    return -(p * p.log() + (1.0 - p) * (1.0 - p).log())
```
Gradient Flow via the Reparameterization Trick
The ELBO is an expectation over the variational posterior:

\[
\text{ELBO}(\phi) = \mathbb{E}_{q(\tau \mid \mathbf{D}, \phi)}\bigl[ f(\tau, \theta) \bigr],
\]
where \(f(\tau, \theta) = \log P(\mathbf{D} \mid \tau, \mu) + \log P(\tau \mid N_e, \rho) + H[q]\). To optimize \(\phi\) by gradient descent, we need \(\nabla_\phi \text{ELBO}\). The reparameterization trick rewrites the expectation so that the randomness is independent of \(\phi\):

\[
\nabla_\phi \, \text{ELBO} = \mathbb{E}_{\epsilon \sim p(\epsilon)}\bigl[ \nabla_\phi f\bigl(\tau(\phi, \epsilon), \theta\bigr) \bigr].
\]
Branch lengths. \(t_{i,\ell} = \exp(\mu_{i,\ell} + \sigma_{i,\ell} \cdot \epsilon)\), where \(\epsilon \sim \mathcal{N}(0,1)\) is fixed noise. The gradient \(\nabla_\phi t_{i,\ell}\) flows through \(\mu\) and \(\sigma\) to the encoder.
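A pure-Python sketch of this pathwise estimator (with assumed illustrative parameters): since \(\partial t / \partial \mu = t\), the reparameterized gradient of \(\mathbb{E}[t]\) with respect to \(\mu\) is just the Monte Carlo mean of \(t\), which should match the closed-form log-normal mean \(e^{\mu + \sigma^2/2}\).

```python
import math
import random

random.seed(0)
mu_q, sigma_q = 1.0, 0.3   # assumed variational parameters

# Pathwise gradient: with t = exp(mu + sigma * eps), dt/dmu = t, so the
# reparameterized estimate of d/dmu E[t] is the Monte Carlo mean of t.
eps = [random.gauss(0.0, 1.0) for _ in range(200_000)]
pathwise_grad = sum(math.exp(mu_q + sigma_q * e) for e in eps) / len(eps)

# Closed form: E[t] = exp(mu + sigma^2 / 2) for a log-normal.
analytic_grad = math.exp(mu_q + sigma_q ** 2 / 2.0)
```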
Breakpoints. Bernoulli variables are discrete, so the standard reparameterization trick does not apply. Instead, Escapement uses the Gumbel-sigmoid relaxation:

\[
\tilde{b}_\ell = \sigma\!\left( \frac{\operatorname{logit}_\ell + g_\ell}{T} \right),
\qquad
g_\ell = \log u_\ell - \log(1 - u_\ell), \quad u_\ell \sim \text{Uniform}(0, 1),
\]

where \(T\) is the relaxation temperature and \(\sigma\) is the sigmoid function.
Topology. The Gumbel-softmax trick (described in Variational Inference Without Simulations) provides gradients through the discrete parent assignments. The straight-through estimator uses hard assignments in the forward pass and soft assignments in the backward pass.
Gradient flow in Escapement:
```
ELBO = log P(D|τ,μ) + log P(τ|Ne,ρ) + H[q]
 │
 ├── ∂/∂t (branch lengths) ← reparameterization trick
 │    └── ∂/∂μ, ∂/∂σ ← backprop through encoder
 │
 ├── ∂/∂π (topology) ← Gumbel-softmax STE
 │    └── ∂/∂α (logits) ← backprop through encoder
 │
 ├── ∂/∂b (breakpoints) ← Gumbel-sigmoid
 │    └── ∂/∂logit ← backprop through encoder
 │
 └── ∂/∂Ne (demography) ← direct gradient
      └── through coalescent prior
```
All three sources of randomness (Gaussian for branch lengths, Gumbel for topology, Gumbel for breakpoints) are sampled independently of \(\phi\), so the gradient estimator \(\nabla_\phi f(\tau(\phi, \epsilon), \theta)\) is unbiased and low-variance. This is the same reparameterization trick used in variational autoencoders (Kingma & Welling 2014), extended to the structured latent space of tree sequences.
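The identity underlying the topology relaxation can be verified directly: adding Gumbel noise to the logits and taking an argmax (the zero-temperature limit of the Gumbel-softmax) samples exactly from the softmax distribution. A pure-Python sketch with assumed logits:

```python
import math
import random
from collections import Counter

random.seed(0)
logits = [1.0, 0.0, -1.0]   # assumed parent-assignment logits

def gumbel():
    # Standard Gumbel noise: -log(-log(u)), u ~ Uniform(0, 1).
    return -math.log(-math.log(random.random()))

# Gumbel-max: argmax(logit + noise) is an exact softmax sample.
counts = Counter(
    max(range(3), key=lambda j: logits[j] + gumbel()) for _ in range(100_000)
)
z = sum(math.exp(l) for l in logits)
softmax = [math.exp(l) / z for l in logits]
empirical = [counts[j] / 100_000 for j in range(3)]
```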
Connection to Source Timepieces
Every component of the differentiable likelihood can be traced to a specific Timepiece:
| Component | Formula | Source Timepiece | What Escapement changes |
|---|---|---|---|
| Poisson mutation model | \(P(\text{diff}) = 1 - e^{-2\mu t s}\) | tsdate (mutation likelihood) | Soft parent assignments instead of fixed topology |
| Exponential coalescent | \(T \sim \text{Exp}(1/(2N_e))\) | msprime (coalescent theory) | Continuous \(N_e(t)\), jointly optimized |
| Piecewise-constant hazard | \(\int_0^t \lambda(s)\, ds\) | PSMC (discretized coalescent) | Direct integration, no discretization of \(t\) |
| SMC breakpoint model | \(P(\text{break}) = 1 - e^{-\rho s}\) | PSMC (SMC approximation) | Learned breakpoint detection, not HMM transitions |
| Gamma/log-normal times | \(t \sim \text{LogNormal}(\mu, \sigma)\) | tsdate (EP-based gamma posteriors) | Neural parameterization instead of hand-derived EP |
| Factored prior | \(P(\tau) \approx \prod_\ell P(\mathcal{T}_\ell)\) | PSMC (SMC factorization) | Same approximation, used in variational objective |
The differentiable likelihood is not a new population-genetic model. It is a compilation of existing models into a form that supports automatic differentiation. The Timepieces provide the equations; Escapement provides the gradient infrastructure.