.. _balance_wheel_posterior:

==============================
Posterior Inference via HMC
==============================

   *The chronometer test does not ask whether a watch keeps perfect time. It
   asks whether the watchmaker knows how imperfect the time is. A movement rated
   to ±2 seconds per day is more trustworthy than one claimed to be exact -- the
   first acknowledges its uncertainty, the second conceals it.*

:ref:`dadi <dadi_timepiece>` and :ref:`moments <moments_timepiece>` find the
maximum likelihood estimate (MLE) of demographic parameters. They report a
point estimate -- the single :math:`\hat{\Theta}` that maximizes the Poisson
log-likelihood -- and, at best, profile likelihood intervals or AIC scores for
model comparison. They do not produce a full posterior distribution over
:math:`\Theta`.

The reason is computational: posterior sampling requires thousands of likelihood
evaluations, and each evaluation costs 10--100 ms with moments or dadi. For
Hamiltonian Monte Carlo (HMC), which needs gradients, the cost per step is even
higher: each of the ~10 leapfrog steps needs a gradient, and central finite
differences over :math:`d` parameters cost :math:`2d` evaluations per gradient.
For an 8-parameter model in dadi, that is :math:`10 \times 2 \times 8 \times
100` ms = 16 s per HMC step, so 10,000 HMC steps take nearly two days.

Balance Wheel changes this arithmetic. At 0.1 ms per evaluation with exact
backpropagation gradients (~0.2 ms per gradient), each HMC step costs ~2 ms and
ten thousand steps take about 20 seconds. Full Bayesian posterior inference
becomes routine.


Why Bayesian Inference Matters
================================

The MLE :math:`\hat{\Theta}` tells you the most likely parameter values, but
not how confident you should be. Two datasets might yield the same MLE but very
different posterior widths -- one tightly constraining the bottleneck time, the
other barely identifying it. Without the posterior, you cannot distinguish these
cases.

Bayesian inference gives you:

1. **Credible intervals** on each parameter. The 95% credible interval
   :math:`[a, b]` means: given the data and the model, the parameter falls in
   :math:`[a, b]` with 95% posterior probability.

2. **Posterior correlations.** Demographic parameters are often correlated --
   e.g., bottleneck depth and duration trade off against each other. The joint
   posterior reveals these correlations; the MLE hides them.

3. **Posterior predictive checks.** Sample parameters from the posterior,
   compute the expected SFS for each sample, and compare to the observed SFS.
   Systematic discrepancies indicate model misspecification.

4. **Model comparison via marginal likelihood.** The marginal likelihood
   :math:`P(\mathbf{D} \mid \mathcal{M})` integrates the likelihood over the
   prior -- a principled metric for comparing models of different complexity
   (e.g., two-epoch vs. three-epoch models).
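
Point 2 can be made concrete with a toy joint posterior. The sketch below
(standard library only; the trade-off coefficients are illustrative, not taken
from any fitted model) draws correlated "bottleneck depth" and "duration"
samples and recovers their correlation, which no point estimate can convey:

.. code-block:: python

   import random
   import statistics

   random.seed(1)

   # Toy joint posterior with a bottleneck-style trade-off: a deeper
   # bottleneck (first coordinate) is compensated by a shorter
   # duration (second coordinate). Coefficients are illustrative.
   depth = [random.gauss(0.0, 1.0) for _ in range(50_000)]
   duration = [-0.8 * d + random.gauss(0.0, 0.6) for d in depth]

   # The joint samples expose the trade-off directly:
   r = statistics.correlation(depth, duration)

Marginal summaries of ``depth`` and ``duration`` alone would hide this
structure entirely.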


The Log-Posterior
===================

The posterior is proportional to the likelihood times the prior:

.. math::

   \log p(\Theta \mid \mathbf{D}) = \underbrace{\ell(\Theta)}_{\text{Poisson
   log-likelihood}} + \underbrace{\log \pi(\Theta)}_{\text{prior}}
   - \underbrace{\log P(\mathbf{D})}_{\text{evidence (constant)}}

The log-likelihood :math:`\ell(\Theta)` is evaluated through Balance Wheel's
neural SFS Predictor. The prior :math:`\pi(\Theta)` encodes our beliefs about
plausible demographic parameters before seeing the data.

Priors on demographic parameters
-----------------------------------

We use weakly informative priors that constrain parameters to physically
plausible ranges without being overly prescriptive:

.. math::

   \log(N_e / N_{\text{ref}}) &\sim \mathcal{N}(0, 2) \\
   \log(t_k) &\sim \mathcal{N}(\mu_t, 1) \\
   m_{ij} &\sim \text{Exponential}(10)

where :math:`N_{\text{ref}}` is a reference population size and :math:`\mu_t`
is a prior mean for the log-time (typically set to the expected TMRCA).

.. code-block:: python

   import torch
   from torch.distributions import Normal, Exponential

   def log_prior(log_sizes, log_times):
       """Weakly informative prior on demographic parameters.

       Implements the priors above with the log-time prior mean
       mu_t set to 0.
       """
       lp = Normal(0, 2).log_prob(log_sizes).sum()   # log(N_e / N_ref)
       lp += Normal(0, 1).log_prob(log_times).sum()  # log(t_k), mu_t = 0
       return lp

   def log_posterior(params, observed_sfs, model, theta_L):
       """Unnormalized log-posterior for HMC sampling."""
       log_sizes, log_times = params[:len(params)//2], params[len(params)//2:]
       # sorting enforces increasing epoch times
       log_times_sorted = torch.sort(log_times)[0]
       log_lik = model.log_likelihood(
           log_sizes.unsqueeze(0), log_times_sorted.unsqueeze(0),
           observed_sfs, theta_L)
       lp = log_prior(log_sizes, log_times_sorted)
       return log_lik + lp


Speed Comparison
==================

The advantage of Balance Wheel for posterior sampling is quantitative:

.. list-table:: Cost per HMC step (10 leapfrog steps, 8 parameters)
   :header-rows: 1
   :widths: 25 20 20 20 15

   * - Method
     - Per eval
     - Gradient
     - Per HMC step
     - 10k steps
   * - dadi (finite diff)
     - 100 ms
     - :math:`2 \times 8 \times 100` ms
     - 16 s
     - 44 hours
   * - moments (AD)
     - 10 ms
     - 20 ms
     - 200 ms
     - 33 min
   * - **Balance Wheel**
     - 0.1 ms
     - 0.2 ms
     - **2 ms**
     - **20 s**

Balance Wheel is 100× faster than moments and 8,000× faster than dadi for HMC
sampling. This is the difference between "theoretically possible but never done"
and "run it during a coffee break."
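
The arithmetic behind the table can be checked directly. A standard-library
sketch using the table's stated assumptions (10 leapfrog steps, 8 parameters,
central differences for dadi, one reverse-mode gradient costing roughly two
evaluations):

.. code-block:: python

   # Assumed per-evaluation costs from the table, in seconds.
   EVAL = {'dadi': 0.100, 'moments': 0.010, 'balance_wheel': 0.0001}
   N_LEAPFROG, N_PARAMS, N_STEPS = 10, 8, 10_000

   # Gradient cost: central differences need 2 * d evaluations for
   # dadi; moments (AD) and Balance Wheel (backprop) pay roughly
   # two evaluations per gradient regardless of d.
   grad = {
       'dadi': 2 * N_PARAMS * EVAL['dadi'],
       'moments': 2 * EVAL['moments'],
       'balance_wheel': 2 * EVAL['balance_wheel'],
   }

   per_step = {m: N_LEAPFROG * g for m, g in grad.items()}  # one HMC step
   total = {m: N_STEPS * s for m, s in per_step.items()}    # 10k steps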


HMC/NUTS Implementation
==========================

Hamiltonian Monte Carlo (see :ref:`mcmc` for the prerequisite) samples from the
posterior by simulating Hamiltonian dynamics on the log-posterior surface. The
No-U-Turn Sampler (NUTS) automatically tunes the number of leapfrog steps,
eliminating the need for manual tuning.

The key requirement is a differentiable log-posterior -- which Balance Wheel
provides via backpropagation through the neural SFS Predictor.

.. code-block:: python

   import torch

   class HMCSampler:
       def __init__(self, log_prob_fn, step_size=0.01, n_leapfrog=10):
           self.log_prob_fn = log_prob_fn
           self.step_size = step_size
           self.n_leapfrog = n_leapfrog

       def _leapfrog(self, q, p):
           q = q.detach().requires_grad_(True)
           log_prob = self.log_prob_fn(q)
           grad = torch.autograd.grad(log_prob, q)[0]
           p = p + 0.5 * self.step_size * grad

           for _ in range(self.n_leapfrog - 1):
               q = q + self.step_size * p
               q = q.detach().requires_grad_(True)
               log_prob = self.log_prob_fn(q)
               grad = torch.autograd.grad(log_prob, q)[0]
               p = p + self.step_size * grad

           q = q + self.step_size * p
           q = q.detach().requires_grad_(True)
           log_prob = self.log_prob_fn(q)
           grad = torch.autograd.grad(log_prob, q)[0]
           p = p + 0.5 * self.step_size * grad

           return q, p, log_prob

       def sample(self, q_init, n_samples=5000, warmup=1000):
           q = q_init.clone()
           samples = []
           log_probs = []
           n_accept = 0

           for i in range(n_samples + warmup):
               p = torch.randn_like(q)
               current_log_prob = self.log_prob_fn(
                   q.detach().requires_grad_(True))
               current_K = 0.5 * (p ** 2).sum()

               q_new, p_new, new_log_prob = self._leapfrog(q, p)
               new_K = 0.5 * (p_new ** 2).sum()

               log_accept = (new_log_prob - current_log_prob
                             + current_K - new_K)

               if torch.log(torch.rand(1)) < log_accept:
                   q = q_new.detach()
                   current_log_prob = new_log_prob
                   n_accept += 1

               if i >= warmup:
                   samples.append(q.clone())
                   # record the log-prob of the retained state, not
                   # the proposal (the two differ after a rejection)
                   log_probs.append(current_log_prob.item())

           accept_rate = n_accept / (n_samples + warmup)
           return torch.stack(samples), log_probs, accept_rate
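
As a sanity check on the integrator itself, the same half-step/full-step/
half-step update order can be exercised on a one-dimensional standard-normal
target without torch. A correct leapfrog nearly conserves the Hamiltonian,
which is what keeps HMC acceptance rates high; the stdlib sketch below
verifies that the energy drift stays small:

.. code-block:: python

   def leapfrog_1d(q, p, eps, n_steps, grad_logp):
       """Same update order as HMCSampler._leapfrog: half momentum
       step, alternating full position/momentum steps, half momentum."""
       p += 0.5 * eps * grad_logp(q)
       for _ in range(n_steps - 1):
           q += eps * p
           p += eps * grad_logp(q)
       q += eps * p
       p += 0.5 * eps * grad_logp(q)
       return q, p

   # Standard-normal target: log p(q) = -q**2 / 2, so grad = -q, and
   # the Hamiltonian H = p**2/2 + q**2/2 should be nearly constant.
   H = lambda q, p: 0.5 * p ** 2 + 0.5 * q ** 2
   q0, p0 = 1.0, 0.5
   q1, p1 = leapfrog_1d(q0, p0, eps=0.01, n_steps=100, grad_logp=lambda q: -q)
   drift = abs(H(q1, p1) - H(q0, p0))  # O(eps**2) for a correct leapfrog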

For production use, we recommend using a NUTS implementation (e.g., from NumPyro
or PyTorch's ecosystem) that automatically tunes the step size during warmup:

.. code-block:: python

   def run_balance_wheel_hmc(model, observed_sfs, theta_L,
                              n_epochs_model=4, n_samples=5000,
                              warmup=1000, device='cuda'):
       """Full posterior inference pipeline."""
       observed_sfs = observed_sfs.float().to(device)
       model.eval().to(device)

       init_sizes = torch.zeros(n_epochs_model, device=device)
       init_times = torch.linspace(-2, 2, n_epochs_model, device=device)
       q_init = torch.cat([init_sizes, init_times])

       def log_prob(params):
           return log_posterior(params, observed_sfs, model, theta_L)

       sampler = HMCSampler(log_prob, step_size=0.005, n_leapfrog=10)
       samples, log_probs, accept_rate = sampler.sample(
           q_init, n_samples=n_samples, warmup=warmup)

       half = n_epochs_model
       size_samples = torch.exp(samples[:, :half])
       time_samples = torch.exp(torch.sort(samples[:, half:])[0])

       return {
           'size_samples': size_samples.cpu(),
           'time_samples': time_samples.cpu(),
           'log_probs': log_probs,
           'accept_rate': accept_rate,
       }


What You Get from the Posterior
==================================

Credible intervals
--------------------

The 95% credible interval for each parameter is simply the 2.5th and 97.5th
percentiles of the posterior samples:

.. code-block:: python

   def credible_intervals(samples, level=0.95):
       alpha = (1 - level) / 2
       lower = torch.quantile(samples, alpha, dim=0)
       upper = torch.quantile(samples, 1 - alpha, dim=0)
       median = torch.quantile(samples, 0.5, dim=0)
       return {'median': median, 'lower': lower, 'upper': upper}

Posterior predictive checks
-----------------------------

For each posterior sample :math:`\Theta^{(s)}`, compute the expected SFS and
compare to the observed SFS. Systematic discrepancies indicate model
misspecification:

.. code-block:: python

   def posterior_predictive_check(model, samples, observed_sfs, theta_L):
       """Generate posterior predictive SFS distribution."""
       n = observed_sfs.shape[0] + 1
       predictive_sfs = []

       for i in range(min(1000, len(samples['size_samples']))):
           log_s = torch.log(samples['size_samples'][i]).unsqueeze(0)
           log_t = torch.log(samples['time_samples'][i]).unsqueeze(0)
           with torch.no_grad():
               pred = model(log_s, log_t, n, theta_L)
           predictive_sfs.append(pred.squeeze(0))

       predictive_sfs = torch.stack(predictive_sfs)
       return {
           'mean': predictive_sfs.mean(dim=0),
           'std': predictive_sfs.std(dim=0),
           'quantile_025': torch.quantile(predictive_sfs, 0.025, dim=0),
           'quantile_975': torch.quantile(predictive_sfs, 0.975, dim=0),
       }

If the observed SFS falls outside the 95% predictive interval for many
frequency classes, the model is likely misspecified -- the demographic model
cannot produce an SFS that looks like the data.

Model comparison
------------------

The marginal likelihood :math:`P(\mathbf{D} \mid \mathcal{M})` can be estimated
from the posterior samples using the harmonic mean estimator (fast but
notoriously high-variance) or bridge sampling (more reliable):

.. math::

   P(\mathbf{D} \mid \mathcal{M}) \approx \left[
   \frac{1}{S} \sum_{s=1}^{S} \frac{1}{P(\mathbf{D} \mid \Theta^{(s)})}
   \right]^{-1}

where :math:`\Theta^{(s)} \sim p(\Theta \mid \mathbf{D})` are posterior
samples. This enables Bayes factor comparison between, e.g., a two-epoch and a
three-epoch model -- a principled alternative to AIC that accounts for posterior
uncertainty.
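
Computed naively, the harmonic mean overflows: per-sample log-likelihoods are
large negative numbers, so :math:`1/P(\mathbf{D} \mid \Theta^{(s)})` is
astronomically large. The estimator should therefore be evaluated in log
space. A minimal standard-library sketch, assuming the per-sample
log-likelihoods are already in hand:

.. code-block:: python

   import math

   def log_marginal_harmonic(log_liks):
       """Harmonic-mean estimate of log P(D | M) from per-sample
       log-likelihoods: log S - logsumexp(-log_liks), computed
       stably by shifting out the maximum before exponentiating."""
       neg = [-ll for ll in log_liks]
       m = max(neg)
       logsumexp = m + math.log(sum(math.exp(x - m) for x in neg))
       return math.log(len(log_liks)) - logsumexp

When all samples share one log-likelihood :math:`\ell_0`, the estimate
collapses to :math:`\ell_0`, which makes for a quick correctness check.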

.. admonition:: Marginal likelihood vs. AIC

   dadi and moments typically use AIC for model comparison:
   :math:`\text{AIC} = -2\ell(\hat{\Theta}) + 2k` where :math:`k` is the
   number of parameters. AIC estimates expected out-of-sample predictive
   error rather than the marginal likelihood (BIC is the large-sample
   approximation to the log marginal likelihood). It is a reasonable guide
   when the likelihood is approximately Gaussian near the MLE and the
   sample size is large relative to :math:`k`. For complex demographic
   models with correlated parameters and multimodal likelihoods, AIC can
   be misleading. The marginal likelihood, estimated from posterior
   samples, is more reliable, but it requires posterior samples, which
   Balance Wheel makes cheap to obtain.


Comparison with Profile Likelihood
=====================================

Profile likelihood is what dadi and moments offer as an approximation to
posterior uncertainty. For a single parameter :math:`\Theta_i`, the profile
likelihood is:

.. math::

   \ell_{\text{profile}}(\Theta_i) = \max_{\Theta_{-i}} \ell(\Theta_i, \Theta_{-i})

where :math:`\Theta_{-i}` denotes all parameters except :math:`\Theta_i`. A
95% confidence interval is constructed by finding the values of :math:`\Theta_i`
where the profile likelihood drops by :math:`\chi^2_{1,0.95}/2 = 1.92` from the
maximum.
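
For a quadratic (Gaussian-shaped) profile log-likelihood, the 1.92 drop
reproduces the familiar :math:`\hat{\Theta}_i \pm 1.96\,\sigma` interval. The
following standard-library sketch verifies this; the quadratic form is an
illustrative assumption, not output from dadi or moments:

.. code-block:: python

   import math

   def profile_ci_quadratic(theta_hat, sigma, drop=1.92):
       """95% profile-likelihood interval when the profile
       log-likelihood is quadratic:
       l(theta) = -0.5 * ((theta - theta_hat) / sigma) ** 2.
       Solving 0.5 * ((theta - theta_hat) / sigma) ** 2 = drop
       gives theta_hat +/- sigma * sqrt(2 * drop)."""
       half_width = sigma * math.sqrt(2 * drop)
       return theta_hat - half_width, theta_hat + half_width

   lo, hi = profile_ci_quadratic(theta_hat=0.0, sigma=1.0)
   # sqrt(2 * 1.92) = sqrt(3.84), approximately 1.96: the familiar
   # Gaussian 95% interval, recovered from the likelihood-drop rule.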

.. list-table:: Profile likelihood vs. full posterior
   :header-rows: 1
   :widths: 25 37 38

   * - Property
     - Profile likelihood (dadi/moments)
     - Full posterior (Balance Wheel)
   * - What it estimates
     - Confidence interval (frequentist)
     - Credible interval (Bayesian)
   * - Captures correlations
     - No (profiles out other parameters by maximization)
     - Yes (joint posterior)
   * - Model comparison
     - AIC (point estimate)
     - Marginal likelihood (integrated)
   * - Computational cost
     - Moderate (grid search per parameter)
     - Low with Balance Wheel (~20 s)
   * - Handles multimodality
     - Poorly (finds local maximum)
     - Better (HMC with multiple chains can reach separate modes)
   * - Prior information
     - Not incorporated
     - Naturally incorporated

The full posterior is strictly more informative than the profile likelihood. It
provides everything the profile likelihood provides (marginal intervals) plus
joint distributions, correlations, predictive checks, and marginal likelihoods.
The only reason it was not used before is computational cost -- and Balance
Wheel removes that barrier.

.. admonition:: When profile likelihood is sufficient

   For well-identified models with approximately Gaussian posteriors (e.g., a
   two-population split-time model with large sample sizes), the profile
   likelihood and the posterior credible interval will agree closely. In this
   regime, the extra cost of HMC sampling may not be justified. Use Balance
   Wheel's full posterior when (1) parameters are correlated, (2) the posterior
   is multimodal or skewed, (3) you need model comparison, or (4) you want
   posterior predictive checks. Use profile likelihood (via moments directly)
   when the model is simple and you trust the Gaussian approximation.
