Rescaling

Even a well-tuned mechanism drifts; the final step is calibrating against the master clock.

After belief propagation (whether inside-outside or variational gamma), tsdate has posterior estimates for every node’s age. But these estimates assume a constant effective population size, which is almost never true. If the population was larger in the past, branches are too long; if it was smaller, they’re too short.

Rescaling corrects for this by comparing the inferred times against the empirical mutation clock: the observed number of mutations in different time windows. In the watch metaphor, this is the final calibration – checking the movement against a master reference clock and adjusting the hands until the ticks match.

This is the same idea used in SINGER’s ARG rescaling (see ARG Rescaling), adapted for tsdate’s continuous posteriors.

Prerequisites

This chapter assumes you have followed the full tsdate pipeline so far: the coalescent prior (The Coalescent Prior), the mutation likelihood (The Mutation Likelihood), and one of the two message passing algorithms (Inside-Outside Belief Propagation or Variational Gamma (Expectation Propagation)). Rescaling is a post-processing step applied to the node times that those algorithms produce.

The Problem: Mismatch Between Model and Reality

The coalescent prior assumes constant \(N_e\). Under this model, the expected time between coalescence events is fixed. But real populations have complex histories: bottlenecks, expansions, migrations.

When \(N_e\) was larger in the past:

  • Real coalescence events were slower (more time between them)

  • The constant-\(N_e\) model underestimates deep times

  • Branches in the deep past are too short

When \(N_e\) was smaller in the past:

  • Real coalescence events were faster

  • The model overestimates deep times

  • Branches in the deep past are too long

Either way, the mutation rate implied by the inferred times won’t match the true mutation rate. Rescaling fixes this.

Think of it as a watch whose mainspring weakens with age: the gears near the present run at the right speed, but the deeper you go, the more the rate drifts. Rescaling adjusts the time scale in each epoch so that the ticks (mutations) per unit time remain constant.

The Key Insight: Mutations Don’t Lie

Whatever the population history, the molecular clock still ticks at rate \(\mu\) per base pair per generation. The total number of mutations in a time window is:

\[\text{mutations in } [t_a, t_b) = \mu \cdot \text{(total branch length in } [t_a, t_b) \text{)}\]

If our estimated times are correct, the ratio \(\text{observed mutations} / \text{expected mutations}\) should be 1.0 in every time window. If it’s consistently \(> 1\) in some window, our branch lengths there are too short, and we need to stretch time. If it’s \(< 1\), we need to compress.

This ratio is a direct diagnostic: any deviation from 1.0 reveals how much the constant-\(N_e\) assumption has distorted the time scale in that epoch.

Step 1: Partition Time into Windows

Divide the time axis \([0, t_{\max})\) into \(J\) windows such that each window contains approximately equal total branch length:

\[[t_0 = 0, t_1), \quad [t_1, t_2), \quad \ldots, \quad [t_{J-1}, t_J)\]

Why equal branch length? So that each window has comparable statistical power for estimating the local mutation rate. A window with very little branch length would have very few mutations and a noisy estimate.

import numpy as np

def partition_time_axis(ts, node_times, J=1000):
    """Partition the time axis into J windows of roughly equal branch length.

    Parameters
    ----------
    ts : tskit.TreeSequence
    node_times : np.ndarray
        Current estimated node times (from EP or inside-outside).
    J : int
        Number of windows.

    Returns
    -------
    breakpoints : np.ndarray, shape (J+1,)
        Window boundaries: [breakpoints[j], breakpoints[j+1]) for window j.
    """
    # Collect all branch lengths weighted by span
    branch_data = []  # (midpoint_time, weighted_length)

    for edge in ts.edges():
        t_parent = node_times[edge.parent]
        t_child = node_times[edge.child]
        span = edge.right - edge.left       # genomic span in bp

        if t_parent > t_child:
            midpoint = (t_parent + t_child) / 2     # time midpoint of the edge
            weighted_length = span * (t_parent - t_child)  # bp * generations
            branch_data.append((midpoint, weighted_length))

    if not branch_data:
        return np.linspace(0, 1, J + 1)

    # Sort by time and find breakpoints with equal cumulative branch length
    branch_data.sort(key=lambda x: x[0])
    times = np.array([b[0] for b in branch_data])
    lengths = np.array([b[1] for b in branch_data])

    cum_length = np.cumsum(lengths)       # running total of branch length
    total_length = cum_length[-1]

    breakpoints = [0.0]
    target_per_window = total_length / J  # each window gets equal share

    for j in range(1, J):
        target = j * target_per_window
        idx = np.searchsorted(cum_length, target)  # find where cumulative exceeds target
        if idx < len(times):
            breakpoints.append(times[idx])
        else:
            breakpoints.append(times[-1])

    breakpoints.append(node_times.max() * 1.01)  # upper bound beyond all nodes
    return np.array(breakpoints)

Step 2: Count Mutations per Window

For each time window, count how many mutations fall in it. A mutation on edge \(e\) is assigned to the time window containing the midpoint of the edge (or, more precisely, proportionally distributed across windows that the edge spans).

def count_mutations_per_window(ts, node_times, breakpoints, mutation_rate):
    """Count observed and expected mutations in each time window.

    Parameters
    ----------
    ts : tskit.TreeSequence
    node_times : np.ndarray
    breakpoints : np.ndarray, shape (J+1,)
    mutation_rate : float

    Returns
    -------
    observed : np.ndarray, shape (J,)
        Mutation count in each window.
    expected : np.ndarray, shape (J,)
        Expected mutations (mu * total branch length) in each window.
    """
    J = len(breakpoints) - 1
    observed = np.zeros(J)
    expected = np.zeros(J)

    # Count mutations per edge (once)
    mut_per_edge = np.zeros(ts.num_edges, dtype=int)
    for mut in ts.mutations():
        if mut.edge >= 0:
            mut_per_edge[mut.edge] += 1

    for edge in ts.edges():
        t_parent = node_times[edge.parent]
        t_child = node_times[edge.child]
        span = edge.right - edge.left    # genomic span
        m_e = mut_per_edge[edge.id]      # observed mutations on this edge

        if t_parent <= t_child:
            continue  # skip edges with zero or negative branch length

        # Distribute this edge's contribution across windows
        for j in range(J):
            w_lo = breakpoints[j]
            w_hi = breakpoints[j + 1]

            # Overlap of edge [t_child, t_parent] with window [w_lo, w_hi]
            overlap_lo = max(t_child, w_lo)
            overlap_hi = min(t_parent, w_hi)

            if overlap_hi > overlap_lo:
                # Fraction of edge in this window
                frac = (overlap_hi - overlap_lo) / (t_parent - t_child)

                # Expected mutations: mu * span * overlap_length
                expected[j] += mutation_rate * span * (overlap_hi - overlap_lo)

                # Observed mutations: distribute proportionally by time overlap
                observed[j] += m_e * frac

    return observed, expected

Step 3: Compute Scaling Factors

For each window, the scaling factor is the ratio of observed to expected mutations:

\[s_j = \frac{\text{observed}_j}{\text{expected}_j}\]

If \(s_j > 1\), the branch lengths in window \(j\) are too short (need stretching). If \(s_j < 1\), they’re too long (need compression).

def compute_scaling_factors(observed, expected, min_count=1.0):
    """Compute per-window scaling factors.

    Parameters
    ----------
    observed, expected : np.ndarray, shape (J,)
    min_count : float
        Minimum mutation count to trust a window.

    Returns
    -------
    scales : np.ndarray, shape (J,)
    """
    scales = np.ones(len(observed))  # default: no rescaling

    for j in range(len(observed)):
        if expected[j] > 0 and observed[j] >= min_count:
            # Ratio > 1 means branches too short; < 1 means too long
            scales[j] = observed[j] / expected[j]

    return scales

Intuition: This is a piecewise-constant estimate of \(N_e(t)\). If the model assumed \(N_e = 10{,}000\) but the true \(N_e\) was \(20{,}000\) during window \(j\), then branch lengths are half what they should be, and \(s_j \approx 2\).

Probability Aside – Rescaling as implicit Ne estimation

The scaling factor \(s_j\) is closely related to the ratio of true \(N_e\) to assumed \(N_e\) in window \(j\). Under the coalescent with variable population size, the coalescent rate is \(1/N_e(t)\). If we used the wrong \(N_e\), the branch lengths in that epoch are stretched or compressed by the ratio \(N_e^{\text{true}} / N_e^{\text{assumed}}\). By setting \(s_j = \text{observed}_j / \text{expected}_j\), we effectively estimate this ratio and correct for it – without explicitly fitting a population-size model. This is similar in spirit to how PSMC (Timepiece I: PSMC) estimates \(N_e(t)\), except here it is a post-processing step rather than the main inference.

Step 4: Apply the Rescaling

Each node’s time is adjusted by the cumulative scaling factor up to its current time:

\[t_u^{\text{new}} = \int_0^{t_u^{\text{old}}} s(x) \, dx\]

where \(s(x)\) is the piecewise-constant scaling function. This integral is just a sum:

\[t_u^{\text{new}} = \sum_{j : t_j < t_u} s_j \cdot \min(t_u, t_{j+1}) - t_j) + s_{j^*} \cdot (t_u - t_{j^*})\]

where \(j^*\) is the window containing \(t_u\).

Calculus Aside – Piecewise integration

The rescaling integral \(\int_0^{t} s(x) \, dx\) with piecewise-constant \(s(x)\) decomposes into a sum of rectangles: in each window \([t_j, t_{j+1})\) where \(s(x) = s_j\), the contribution is \(s_j \cdot (t_{j+1} - t_j)\). For the final (partial) window containing \(t\), the contribution is \(s_{j^*} \cdot (t - t_{j^*})\). The result is a piecewise-linear, monotonically increasing function of the original time – a “warped” time axis that stretches or compresses different epochs.

def apply_rescaling(node_times, breakpoints, scales, fixed_nodes):
    """Apply piecewise rescaling to node times.

    Parameters
    ----------
    node_times : np.ndarray
        Current node times (will not be modified).
    breakpoints : np.ndarray, shape (J+1,)
    scales : np.ndarray, shape (J,)
    fixed_nodes : set
        Nodes whose times should not change (e.g., samples).

    Returns
    -------
    new_times : np.ndarray
        Rescaled node times.
    """
    new_times = np.zeros_like(node_times)
    J = len(scales)

    # Build cumulative scaling function
    # cum_rescaled[j] = rescaled time at window boundary j
    cum_rescaled = np.zeros(J + 1)
    for j in range(J):
        window_width = breakpoints[j + 1] - breakpoints[j]
        cum_rescaled[j + 1] = cum_rescaled[j] + scales[j] * window_width

    for u in range(len(node_times)):
        if u in fixed_nodes:
            new_times[u] = node_times[u]  # samples stay fixed
            continue

        t = node_times[u]

        # Find which window t falls in
        j = np.searchsorted(breakpoints, t, side='right') - 1
        j = min(j, J - 1)
        j = max(j, 0)

        # Rescaled time = cumulative up to window j + fraction within window
        fraction_in_window = t - breakpoints[j]
        new_times[u] = cum_rescaled[j] + scales[j] * fraction_in_window

    return new_times

Iterating the Rescaling

A single round of rescaling may not be sufficient because the window boundaries depend on the node times, which change after rescaling. tsdate iterates:

  1. Compute window boundaries from current times

  2. Count mutations per window

  3. Compute scaling factors

  4. Apply rescaling to get new times

  5. Repeat (default: 5 iterations)

Each iteration refines the time scale, like adjusting a regulator screw on a mechanical watch – small turns that progressively bring the rate into alignment with the master clock.

def iterative_rescaling(ts, node_times, mutation_rate, fixed_nodes,
                        J=1000, num_iter=5):
    """Iteratively rescale node times to match the mutation clock.

    Parameters
    ----------
    ts : tskit.TreeSequence
    node_times : np.ndarray
    mutation_rate : float
    fixed_nodes : set
    J : int
        Number of time windows.
    num_iter : int
        Number of rescaling iterations.

    Returns
    -------
    node_times : np.ndarray
        Rescaled node times.
    """
    times = node_times.copy()

    for iteration in range(num_iter):
        # 1. Partition time into equal-branch-length windows
        breakpoints = partition_time_axis(ts, times, J)

        # 2. Count observed vs. expected mutations per window
        observed, expected = count_mutations_per_window(
            ts, times, breakpoints, mutation_rate)

        # 3. Compute scaling factors (observed / expected)
        scales = compute_scaling_factors(observed, expected)

        # 4. Apply piecewise rescaling
        times = apply_rescaling(times, breakpoints, scales, fixed_nodes)

    return times

Connection to Population Size History

The scaling factors \(s_j\) are intimately related to the effective population size history. Under the coalescent with variable \(N_e(t)\):

  • The rate of coalescence at time \(t\) is \(1 / N_e(t)\)

  • The mutation rate is constant at \(\mu\)

If we model the coalescent under constant \(N_e^{(0)}\) but the true population size in window \(j\) is \(N_e^{(j)}\), then:

\[s_j \approx \frac{N_e^{(j)}}{N_e^{(0)}}\]

So the rescaling implicitly estimates the population size history. This is similar to what PSMC does (see the Timepiece I: PSMC), but here it’s a post-processing step rather than the main inference.

Edge Cases and Robustness

Several practical issues arise:

Windows with few mutations: If a window has very few mutations (or none), the scaling factor is unreliable. tsdate handles this by:

  • Setting a minimum count threshold

  • Smoothing adjacent scaling factors

  • Falling back to a scale of 1.0 for empty windows

Negative branch lengths: After rescaling, some edges might end up with the parent younger than the child. tsdate enforces constraints by adjusting times to maintain the topological ordering.

Convergence: Rescaling typically converges within 3-5 iterations. The scaling factors stabilize as the times settle into their correct positions.

The Full Pipeline

Putting rescaling together with EP, here is the complete tsdate pipeline from raw tree sequence to dated genealogy.

def tsdate_full_pipeline(ts, mutation_rate, Ne=1.0, max_ep_iter=25,
                         rescaling_intervals=1000, rescaling_iterations=5):
    """The complete tsdate pipeline.

    Parameters
    ----------
    ts : tskit.TreeSequence
        Input (topology from tsinfer).
    mutation_rate : float
    Ne : float
    max_ep_iter : int
    rescaling_intervals : int
    rescaling_iterations : int

    Returns
    -------
    dated_ts : np.ndarray
        Posterior mean node times.
    """
    # Step 1: Build priors (Gear 1 -- the expected beat rate)
    prior_grid = build_coalescent_priors(ts, Ne)

    # Step 2: Run EP (Gear 4 -- messages flow through the gear train)
    posteriors = run_ep(ts, mutation_rate, prior_grid, max_ep_iter)

    # Step 3: Extract posterior means
    node_times = np.zeros(ts.num_nodes)
    for u in range(ts.num_nodes):
        if u in posteriors:
            node_times[u] = posteriors[u].mean

    fixed_nodes = set(ts.samples())
    for s in fixed_nodes:
        node_times[s] = 0.0

    # Step 4: Rescale (Gear 5 -- calibrate against the mutation clock)
    if rescaling_iterations > 0:
        node_times = iterative_rescaling(
            ts, node_times, mutation_rate, fixed_nodes,
            J=rescaling_intervals,
            num_iter=rescaling_iterations
        )

    return node_times

Summary

Rescaling is tsdate’s final calibration step – the last gear in the mechanism:

  1. Partition the time axis into \(J\) windows of equal branch length

  2. Count observed vs. expected mutations per window

  3. Scale each window by the ratio \(s_j = \text{observed}/\text{expected}\)

  4. Apply the piecewise scaling to all node times

  5. Iterate until convergence (default: 5 rounds)

The key equation:

\[t_u^{\text{new}} = \int_0^{t_u^{\text{old}}} \frac{\text{observed mutations}(x)} {\text{expected mutations}(x)} \, dx\]

This corrects for variable population size without explicitly modeling it, by letting the molecular clock be the final arbiter of time. In the watch metaphor, the mutation clock is the master reference – it ticks at a known rate (\(\mu\)) regardless of population history, and rescaling adjusts every hand on the dial until the ticks match.

Congratulations – you’ve now built every gear of the tsdate mechanism:

  1. Coalescent prior – the expected beat rate from coalescent theory: informed starting beliefs about node ages

  2. Mutation likelihood – evidence from the mutation clock: the Poisson model connecting observed data to branch lengths

  3. Inside-outside – messages flowing through the gear train on a discrete grid

  4. Variational gamma – the same messages, now carried by continuous gamma distributions via expectation propagation

  5. Rescaling – calibrating the clock against the master reference: adjusting for variable population size

Together, these gears transform a topology-only tree sequence (from tsinfer) into a fully dated genealogy. You understand the math, the code, and the intuition behind every step. The watch is assembled, calibrated, and keeping time.