Memory-Efficient Viterbi Inference
The classical machine works, but it does not fit in the case. Redesign it.
The second step in the Threads pipeline takes the candidate matches from the PBWT pre-filter and runs a memory-efficient Viterbi algorithm under the Li-Stephens model to find the optimal threading path for each sample.
The Classical Viterbi Limitation
Given a reference panel \(H\) of \(N\) haplotypes over \(M\) sites and a query haplotype \(g\), the Li-Stephens model assigns a probability \(P(\pi)\) to each path \(\pi \in \{1, \ldots, N\}^M\) through the panel. A Viterbi path is a path of maximum probability.
The classical Viterbi algorithm finds this path in \(O(NM)\) time by constructing a full \(N \times M\) probability matrix and then performing a traceback. The \(O(NM)\) time, rather than the \(O(N^2M)\) of a generic HMM, relies on the Li-Stephens transition structure. For biobank-scale data this matrix is prohibitively large: even after PBWT pre-filtering reduces \(N\) to \(L\) candidate matches, storing an \(L \times M\) matrix over long genomic tracts still requires a large amount of memory.
Two Key Observations
The Threads-Viterbi algorithm exploits two properties of the Li-Stephens model:
Observation 1: Recombination events are rare. Viterbi paths consist of few but long segments. In the 1000 Genomes Project (2,251 sequences), the average segment length exceeds 200 kilobases. In UK Biobank array data (\(N = 337{,}464\)), segments average well over a megabase. This means a complete Viterbi path can be stored compactly by recording only the segment breakpoints and threading targets.
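To make Observation 1 concrete, here is a minimal sketch (the helper name `compress_path` is hypothetical) that run-length encodes a per-site threading path into `(start_site, target)` pairs. Storage then grows with the number of breakpoints rather than with the number of sites.

```python
def compress_path(path):
    """Run-length encode a per-site path into (start_site, target) pairs."""
    segments = []
    for site, target in enumerate(path):
        # A new segment begins wherever the threading target changes
        if not segments or segments[-1][1] != target:
            segments.append((site, target))
    return segments


# Hypothetical path over 2,000 sites with two breakpoints
path = [3] * 500 + [7] * 700 + [3] * 800
segments = compress_path(path)
print(f"{len(path)} per-site entries -> {len(segments)} segments")
print(segments)  # [(0, 3), (500, 7), (1200, 3)]
```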
Observation 2: Recombination is symmetric. Under the Li-Stephens model,
\[
P(\pi_{i+1} = \beta \mid \pi_i = \alpha) = P(\pi_{i+1} = \alpha \mid \pi_i = \beta)
\]
for any states \(\alpha, \beta\). This symmetry dramatically reduces the search space for possible Viterbi paths, as formalized in Proposition 1.
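Observation 2 can be checked numerically. The sketch below assumes the standard Li-Stephens transition form with a single per-site recombination probability \(\rho\), namely \(P(\pi_{i+1} = \beta \mid \pi_i = \alpha) = (1 - \rho)\,\delta_{\alpha\beta} + \rho / N\); the function name `li_stephens_transition` is hypothetical. Because each entry depends on \(\alpha\) and \(\beta\) only through whether they are equal, the matrix is symmetric.

```python
import numpy as np


def li_stephens_transition(n_haplotypes, rho):
    """Transition matrix T[a, b] = P(next state = b | current state = a)."""
    # Stay with probability (1 - rho), plus a uniform jump of rho / N to any state
    return (1.0 - rho) * np.eye(n_haplotypes) + rho / n_haplotypes


T = li_stephens_transition(5, 0.01)
print(np.allclose(T, T.T))              # True: symmetric in (alpha, beta)
print(np.allclose(T.sum(axis=1), 1.0))  # True: each row is a distribution
```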
Proposition 1
Proposition 1. Suppose \(\pi^{(i)}\) is a Viterbi path for the subset of the panel \(H\) containing sites 1 through \(i\). If there exists a Viterbi path through \(H\) that recombines between sites \(i\) and \(i + 1\), then there exists a Viterbi path \(\pi\) through \(H\) satisfying \(\pi_i = \pi^{(i)}_i\).
In plain language: at site \(i\), the only recombinations worth considering are those out of the state occupied by a highest-probability path given all observations up to \(i\). This is the property that makes the branch-and-bound strategy correct.
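Proposition 1 can be sanity-checked by brute force on a tiny panel. The sketch below is hypothetical and not part of Threads: it draws continuous random emission penalties (so that optimal paths are, in practice, unique), charges \(\rho_c\) for staying and \(\rho\) for switching with \(\rho_c < \rho\), enumerates every path, and verifies that wherever the optimal path recombines between sites \(i\) and \(i+1\), its state at site \(i\) equals the last state of the best prefix path over sites up to \(i\). Sites are indexed from 0 here.

```python
import itertools

import numpy as np


def path_penalty(path, mu, rho, rho_c):
    """Emission penalties along the path, plus rho_c per stay and rho per switch."""
    total = sum(mu[n, m] for m, n in enumerate(path))
    total += sum(rho_c if a == b else rho for a, b in zip(path, path[1:]))
    return total


def check_proposition_1(n_hap=3, n_sites=5, rho=2.0, rho_c=0.1, seed=0):
    """Brute-force check on one random panel; True if Proposition 1 holds."""
    rng = np.random.default_rng(seed)
    # Continuous random emission penalties make exact ties essentially impossible
    mu = rng.uniform(0.0, 3.0, size=(n_hap, n_sites))
    all_paths = itertools.product(range(n_hap), repeat=n_sites)
    viterbi = min(all_paths, key=lambda p: path_penalty(p, mu, rho, rho_c))
    for i in range(n_sites - 1):
        if viterbi[i] == viterbi[i + 1]:
            continue  # no recombination between sites i and i+1
        # Best prefix path over sites 0..i (pi^(i) in the proposition)
        prefixes = itertools.product(range(n_hap), repeat=i + 1)
        best_prefix = min(prefixes, key=lambda p: path_penalty(p, mu, rho, rho_c))
        if best_prefix[-1] != viterbi[i]:
            return False
    return True


print(all(check_proposition_1(seed=s) for s in range(20)))  # True
```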
The Segment Set
The algorithm maintains a set \(\Omega\) of path segments, each consisting of:
A start site \(m_\omega \in \{0, 1, \ldots, M\}\), where \(m_\omega = 0\) marks a root segment at the beginning of the path
A threading target \(n_\omega \in \{1, \ldots, N\}\)
If \(m_\omega > 0\), a traceback segment \(\omega' \in \Omega\) with \(m_{\omega'} < m_\omega\)
A full path through \(H\) is constructed by starting at any segment and following traceback pointers until a segment with \(m_\omega = 0\) is reached. The penalty (negative log-likelihood) of a path ending at \(\omega\) is denoted \(s(\omega)\).
Of the segments in \(\Omega\), exactly \(N\) are active, \(\omega_1, \ldots, \omega_N\), one per reference haplotype. Writing \(P(\omega)\) for the path obtained by following traceback pointers from \(\omega\), the segment set is complete if each \(P(\omega_n)\) is a Li-Stephens-optimal path ending at haplotype \(n\). When the set is complete, the active segment with minimum penalty gives a Viterbi path.
```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Segment:
    """A path segment in the Threads-Viterbi algorithm.

    Parameters
    ----------
    start : int
        Start site of this segment.
    target : int
        Reference haplotype index this segment copies from.
    penalty : float
        Cumulative negative log-likelihood up to this segment.
    parent : Segment or None
        Traceback pointer to the previous segment.
    """
    start: int
    target: int
    penalty: float
    parent: Optional['Segment'] = field(default=None, repr=False)

    def traceback(self):
        """Follow traceback pointers to reconstruct the full path."""
        path = []
        seg = self
        while seg is not None:
            path.append((seg.start, seg.target))
            seg = seg.parent
        return list(reversed(path))


# Demonstrate the segment structure
seg0 = Segment(start=0, target=3, penalty=10.0)
seg1 = Segment(start=500, target=7, penalty=15.0, parent=seg0)
seg2 = Segment(start=1200, target=3, penalty=22.0, parent=seg1)

path = seg2.traceback()
print("Viterbi path segments:")
for start, target in path:
    print(f"  Site {start}: copy from haplotype {target}")
print(f"Total penalty: {seg2.penalty:.1f}")
```
The Branch Step (Theorem 1)
Given a complete segment set \(\Omega^{(m)}\) for sites \(1\) through \(m\), the branch step constructs \(\Omega^{(m+1)}\) that is complete for sites \(1\) through \(m + 1\).
Let \(\omega'\) be the active segment with minimum penalty (the current Viterbi path), and let \(\rho\) and \(\rho_c\) be the penalties for recombination and no-recombination respectively.
For each active segment \(\omega_n\), if continuing without recombination is worse than recombining from the best path,
\[
s(\omega_n) + \rho_c > s(\omega') + \rho,
\]
then a new segment \(\omega(m+1, n, \omega')\) (start site \(m+1\), threading target \(n\), traceback segment \(\omega'\)) is created, representing a recombination from the best path to haplotype \(n\) at site \(m+1\).
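The new active segment \(\omega_n^{(m+1)}\) for haplotype \(n\) is whichever option has lower penalty:
\[
s\big(\omega_n^{(m+1)}\big) = \min\big\{\, s(\omega_n) + \rho_c,\; s(\omega') + \rho \,\big\} + \mu_n,
\]
where \(\mu_n\) is the match/mismatch penalty at site \(m + 1\).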
Complexity per site: The branch step adds at most \(N\) new segments, giving \(O(NM)\) total segments across all sites. In practice, new segments are created only at inferred recombination events, which are rare.
```python
import numpy as np


def branch_step(active_segments, site, query_allele, ref_alleles,
                rho_penalty, rho_c_penalty, mismatch_penalty):
    """Perform the branch step at one site.

    For each active segment, decide whether continuing without
    recombination is better than recombining from the best path.

    Parameters
    ----------
    active_segments : list of Segment
        One active segment per reference haplotype.
    site : int
        Index of the site being processed.
    query_allele : int
        Query haplotype's allele at this site (0 or 1).
    ref_alleles : ndarray, shape (N,)
        Reference panel alleles at this site.
    rho_penalty : float
        Penalty for recombination (-log(rho/N)).
    rho_c_penalty : float
        Penalty for no recombination (-log(1 - rho)).
    mismatch_penalty : float
        Penalty for allele mismatch.

    Returns
    -------
    new_active : list of Segment
        Updated active segments for the next site.
    n_new_segments : int
        Number of new segments created (recombination events).
    """
    # Find the best current path (minimum penalty). Freeze its penalty,
    # because continuing segments are updated in place below.
    best = min(active_segments, key=lambda s: s.penalty)
    best_penalty = best.penalty

    new_active = []
    n_new = 0
    for n, seg in enumerate(active_segments):
        # Emission penalty: mismatch if alleles differ
        mu_n = mismatch_penalty if ref_alleles[n] != query_allele else 0.0
        # Cost of continuing vs. recombining from the best path
        cost_continue = seg.penalty + rho_c_penalty + mu_n
        cost_recombine = best_penalty + rho_penalty + mu_n
        if cost_recombine < cost_continue:
            # Create a new segment (recombination event)
            new_active.append(Segment(start=site, target=n,
                                      penalty=cost_recombine, parent=best))
            n_new += 1
        else:
            # Extend the existing segment in place: no new segment is stored
            seg.penalty = cost_continue
            new_active.append(seg)
    return new_active, n_new


# Demonstrate on a small example
N = 4
ref = np.array([[0, 0, 1, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 1, 1, 0, 0],
                [0, 1, 1, 0, 0, 0, 0, 1],
                [0, 0, 1, 0, 0, 1, 1, 0]])
query = np.array([0, 0, 1, 0, 0, 1, 0, 1])
mismatch = 3.0

# Root segments start at site 0 and pay the site-0 emission penalty
active = [Segment(start=0, target=n,
                  penalty=mismatch if ref[n, 0] != query[0] else 0.0)
          for n in range(N)]
total_new = 0
for site in range(1, len(query)):
    active, n_new = branch_step(
        active, site, query[site], ref[:, site],
        rho_penalty=5.0, rho_c_penalty=0.01, mismatch_penalty=mismatch
    )
    total_new += n_new

best = min(active, key=lambda s: s.penalty)
path = best.traceback()
print(f"Query: {query.tolist()}")
print(f"Segments created by branch steps: {total_new}")
print(f"Viterbi path ({len(path) - 1} recombination events):")
for start, target in path:
    print(f"  From site {start}: copy haplotype {target}")
print(f"Final penalty: {best.penalty:.2f}")
```
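As an independent sanity check of Theorem 1 (not part of Threads itself), the sketch below re-implements the branch step with segments stored as `(start, target, parent_index)` tuples in a flat list, traces back the best path, and confirms that the traced path, re-scored from scratch, attains the same optimal penalty as a plain per-site min-plus recursion. All function names are hypothetical, and the same simplified penalties (\(\rho\), \(\rho_c\), fixed mismatch) are assumed.

```python
import numpy as np


def segment_viterbi(ref, query, rho, rho_c, mm):
    """Branch steps only; segments are (start, target, parent_index) tuples."""
    N, M = ref.shape
    segments = [(0, n, -1) for n in range(N)]  # root segments
    active = list(range(N))                    # active segment index per haplotype
    cost = np.where(ref[:, 0] != query[0], mm, 0.0)
    for m in range(1, M):
        b = int(np.argmin(cost))
        best_seg, best_cost = active[b], float(cost[b])  # frozen before updates
        for n in range(N):
            if best_cost + rho < cost[n] + rho_c:
                segments.append((m, n, best_seg))  # recombination event
                active[n] = len(segments) - 1
                cost[n] = best_cost + rho
            else:
                cost[n] += rho_c                   # extend without recombination
        cost += np.where(ref[:, m] != query[m], mm, 0.0)
    # Traceback: follow parent indices and expand to one target per site
    path, end = np.empty(M, dtype=int), M
    idx = active[int(np.argmin(cost))]
    while idx != -1:
        start, target, idx = segments[idx]
        path[start:end] = target
        end = start
    return path, float(cost.min()), len(segments)


def path_penalty(path, ref, query, rho, rho_c, mm):
    """Score a per-site path from scratch under the same penalties."""
    emitted = ref[path, np.arange(len(path))]
    total = float(np.sum(np.where(emitted != query, mm, 0.0)))
    total += sum(rho_c if a == b else rho for a, b in zip(path[:-1], path[1:]))
    return total


def min_plus_penalty(ref, query, rho, rho_c, mm):
    """Optimal penalty from the per-site recursion (no traceback stored)."""
    cost = np.where(ref[:, 0] != query[0], mm, 0.0)
    for m in range(1, ref.shape[1]):
        cost = np.minimum(cost + rho_c, cost.min() + rho)
        cost += np.where(ref[:, m] != query[m], mm, 0.0)
    return float(cost.min())


rng = np.random.default_rng(0)
ref = rng.integers(0, 2, size=(6, 40))
query = rng.integers(0, 2, size=40)
path, penalty, n_segments = segment_viterbi(ref, query, 5.0, 0.01, 3.0)
print(n_segments, round(penalty, 2),
      round(path_penalty(path, ref, query, 5.0, 0.01, 3.0), 2),
      round(min_plus_penalty(ref, query, 5.0, 0.01, 3.0), 2))
```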
The Bound Step (Theorem 2)
The bound step prunes the segment set without losing completeness.
Theorem 2. Let \(\Omega\) be a complete segment set with active segments \(\omega_1, \ldots, \omega_N\). Define \(\Omega^* \subseteq \Omega\) as the union of all traceback paths from the active segments: \(\Omega^* = \bigcup_{n=1}^{N} P(\omega_n)\). Then \(\Omega^*\) is also a complete segment set.
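The bound step discards every segment that is not on a traceback path from an active segment. Such segments are undercut: the branch step only ever attaches new segments to currently active ones, so a segment that is unreachable now stays unreachable and can never be part of an optimal path.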
Threads triggers the bound step adaptively, using a heuristic threshold on the size of the segment set:
Initialize with \(B_0 = 10 \cdot N\)
If \(|\Omega^{(m)}| > B_0\), prune to \(\Omega^*\)
If the next pruning trigger occurs within 30 sites, double the threshold to \(B_1 = 2B_0\); otherwise reset to \(B_0\)
This balances pruning frequency against the risk of memory spikes from rapid segment accumulation.
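The trigger rule can be sketched as a small stateful helper. This is a hypothetical illustration, not the Threads source: it assumes the threshold switches between exactly two levels, \(B_0 = 10 \cdot N\) and \(2B_0\), and that "within 30 sites" is measured between consecutive pruning triggers.

```python
class PruneScheduler:
    """Hypothetical helper for the adaptive bound-step trigger.

    Assumes two threshold levels (B0 = 10 * N and 2 * B0) and measures
    "within `window` sites" between consecutive pruning triggers.
    """

    def __init__(self, n_haplotypes, window=30):
        self.base = 10 * n_haplotypes  # B_0
        self.threshold = self.base
        self.window = window
        self.last_prune_site = None

    def should_prune(self, site, n_segments):
        """Return True if the segment set should be pruned at this site."""
        if n_segments <= self.threshold:
            return False
        recent = (self.last_prune_site is not None
                  and site - self.last_prune_site <= self.window)
        # Triggers close together: raise the threshold; otherwise reset it
        self.threshold = 2 * self.base if recent else self.base
        self.last_prune_site = site
        return True


sched = PruneScheduler(n_haplotypes=10)  # B_0 = 100
for site, count in [(5, 150), (20, 120), (40, 150), (60, 250)]:
    print(site, sched.should_prune(site, count), sched.threshold)
```

The second trigger comes 15 sites after the first, so the threshold rises to 200 and the count of 150 at site 40 no longer forces a prune; the next trigger is 40 sites later, so the threshold resets to 100.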
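```python
def bound_step(all_segments, active_segments):
    """Prune segments not on any active traceback path.

    Parameters
    ----------
    all_segments : list of Segment
        All segments accumulated so far.
    active_segments : list of Segment
        Currently active segments (one per haplotype).

    Returns
    -------
    pruned : list of Segment
        Only segments reachable from active traceback paths.
    """
    # Collect all segments on traceback paths from active segments.
    # Paths share prefixes, so stop as soon as a segment was already seen.
    reachable = set()
    for seg in active_segments:
        current = seg
        while current is not None and id(current) not in reachable:
            reachable.add(id(current))
            current = current.parent
    return [s for s in all_segments if id(s) in reachable]


# Demonstrate pruning on a small hand-built segment tree
# (penalties are placeholders; only the traceback pointers matter here)
roots = [Segment(start=0, target=n, penalty=0.0) for n in range(4)]
s1 = Segment(start=3, target=1, penalty=5.0, parent=roots[0])  # replaces roots[1]
s3 = Segment(start=3, target=3, penalty=5.0, parent=roots[0])  # replaces roots[3]
s2 = Segment(start=6, target=2, penalty=9.0, parent=s1)        # replaces roots[2]
s3b = Segment(start=7, target=3, penalty=9.5, parent=s2)       # replaces s3
all_segments = roots + [s1, s3, s2, s3b]
tree_active = [roots[0], s1, s2, s3b]  # one active segment per haplotype

pruned = bound_step(all_segments, tree_active)
print("\nBound step example:")
print(f"  Before pruning: {len(all_segments)} segments")
print(f"  After pruning:  {len(pruned)} segments")
print("  (Undercut segments are discarded)")
```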
Traceback
After processing all \(M\) sites, the final Viterbi path is recovered by:
Identifying the active segment \(\omega^*\) with minimum penalty
Following traceback pointers from \(\omega^*\) until reaching a segment with start site 0
The result is a piecewise-constant path through the reference panel: a sequence of segments, each specifying a threading target and a genomic interval.
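A minimal sketch of that final conversion, assuming the traceback returns `(start, target)` pairs in genomic order as `Segment.traceback` does; the helper name `to_intervals` is hypothetical. Each segment runs until the next one starts, and the last runs to the end of the region.

```python
def to_intervals(segments, n_sites):
    """Turn traceback output [(start, target), ...] into half-open intervals.

    Returns (start, end, target) triples covering sites [0, n_sites).
    """
    # Each segment ends where the next one begins; the last ends at n_sites
    ends = [start for start, _ in segments[1:]] + [n_sites]
    return [(start, end, target) for (start, target), end in zip(segments, ends)]


segments = [(0, 3), (500, 7), (1200, 3)]  # e.g. the output of a traceback
for start, end, target in to_intervals(segments, n_sites=2000):
    print(f"sites [{start}, {end}): copy haplotype {target}")
```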
Parallelism
A critical property of the Threads-Viterbi algorithm: the \(N\) Viterbi instances (one per sample) are completely independent. The output of each HMM does not depend on any other. This means all \(N\) instances can run in parallel, divided evenly among available CPU cores.
Given \(L\) haplotype matches per sample and \(N_{\text{CPU}}\) cores:
Time: \(O(MLN / N_{\text{CPU}})\)
Memory: \(O(LN)\) average
The genotype data is streamed from disk once per core, and neither the full genotype matrix nor any \(N \times M\) probability matrix is ever stored in memory.
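The sketch below illustrates the independence argument under stated assumptions: samples are split into contiguous chunks whose sizes differ by at most one, each chunk is processed with no shared state, and the results do not depend on how many workers are used. The per-sample job is a toy min-plus Viterbi penalty, and all names are hypothetical. It uses `ThreadPoolExecutor` only so the example runs anywhere without pickling concerns; a CPU-bound deployment would use separate processes or native threads.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def split_evenly(n_samples, n_cpu):
    """Contiguous chunks of sample indices whose sizes differ by at most one."""
    base, extra = divmod(n_samples, n_cpu)
    chunks, start = [], 0
    for c in range(n_cpu):
        size = base + (1 if c < extra else 0)
        chunks.append(range(start, start + size))
        start += size
    return chunks


def best_penalty(ref, query, rho=5.0, rho_c=0.01, mm=3.0):
    """Toy per-sample job: optimal min-plus Viterbi penalty for one query."""
    cost = np.where(ref[:, 0] != query[0], mm, 0.0)
    for m in range(1, ref.shape[1]):
        cost = np.minimum(cost + rho_c, cost.min() + rho)
        cost += np.where(ref[:, m] != query[m], mm, 0.0)
    return float(cost.min())


def run_chunk(ref, queries, indices):
    # Each sample reads only its own query; nothing is shared between samples
    return [(i, best_penalty(ref, queries[i])) for i in indices]


def run_all(ref, queries, n_cpu):
    chunks = split_evenly(len(queries), n_cpu)
    with ThreadPoolExecutor(max_workers=n_cpu) as pool:
        parts = list(pool.map(lambda idx: run_chunk(ref, queries, idx), chunks))
    return dict(pair for part in parts for pair in part)


rng = np.random.default_rng(0)
ref = rng.integers(0, 2, size=(8, 50))       # toy panel: 8 haplotypes, 50 sites
queries = rng.integers(0, 2, size=(10, 50))  # 10 toy samples
print([len(c) for c in split_evenly(10, 4)])  # [3, 3, 2, 2]
print(run_all(ref, queries, 1) == run_all(ref, queries, 4))  # True
```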
Complexity Summary
| Property | Classical Viterbi | Threads-Viterbi |
|---|---|---|
| Time (per sample) | \(O(NM)\) | \(O(NM)\) (same) |
| Memory (per sample) | \(O(NM)\) | \(O(N)\) average |
| Total (all samples) | \(O(MN^2)\) time + memory | \(O(MLN/N_{\text{CPU}})\) time, \(O(LN)\) memory |
| Parallelism | Not straightforward | Embarrassingly parallel |