Lab 04 — Centroid Decomposition

Goal

Implement centroid decomposition for efficient tree queries — counting / aggregating over all paths in a tree in O(N log N) or O(N log² N).

Background

The centroid of a tree is a vertex whose removal leaves no connected component with more than N/2 vertices. Every tree has either one centroid or two adjacent centroids (two only when N is even).
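The definition can be checked directly. The sketch below is illustrative only: `find_centroids` and its edge-list input format are hypothetical names chosen here, not part of the lab's reference solution. It computes subtree sizes once and then tests each vertex against the "no component larger than N/2" condition.

```python
def find_centroids(n, edges):
    """Return all centroids of a tree with vertices 0..n-1 (hypothetical helper)."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from vertex 0: record each vertex's parent and the visiting order.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order, stack = [], [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    # Parents appear before their children in `order`, so the reverse is children-first.
    size = [1] * n
    for u in reversed(order):
        if parent[u] != -1:
            size[parent[u]] += size[u]

    centroids = []
    for u in range(n):
        # Components left after deleting u: each child subtree plus the "upward" rest.
        largest = n - size[u]
        for v in adj[u]:
            if v != parent[u]:
                largest = max(largest, size[v])
        if 2 * largest <= n:
            centroids.append(u)
    return centroids
```

On a 4-vertex chain both middle vertices qualify, which is the two-centroid case mentioned above.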

Centroid decomposition: recursively decompose the tree:

  1. Find centroid; process all paths passing through it
  2. Remove centroid; recurse on each remaining subtree

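Recursion depth: O(log N), because every component left after removing a centroid has at most half the vertices. Each vertex belongs to exactly one component per level, so a level costs O(N) when the per-centroid work is linear (O(N log N) when it sorts), giving O(N log N) or O(N log² N) total.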
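To see the depth bound concretely, the following sketch records the level at which each vertex is chosen as a centroid. It is a minimal illustration under assumptions: `centroid_depths` is a hypothetical helper, the tree is unweighted, vertices are 0..n-1 with n ≥ 1, and an explicit work list replaces recursion so that long chains do not hit Python's recursion limit.

```python
def centroid_depths(n, edges):
    """Depth of every vertex in the centroid tree (hypothetical helper)."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    removed = [False] * n
    parent = [-1] * n
    size = [0] * n
    depth = [0] * n
    work = [(0, 0)]  # (some vertex of a component, centroid-tree depth of that component)
    while work:
        start, d = work.pop()
        # BFS over the current component; the list grows while we iterate over it.
        parent[start] = -1
        order = [start]
        for u in order:
            for v in adj[u]:
                if v != parent[u] and not removed[v]:
                    parent[v] = u
                    order.append(v)
        # Recompute subtree sizes for this component only (children before parents).
        for u in order:
            size[u] = 1
        for u in reversed(order):
            if parent[u] != -1:
                size[parent[u]] += size[u]
        # Walk toward the child holding more than half of the component, if any.
        total, c = len(order), start
        while True:
            heavy = next((v for v in adj[c]
                          if not removed[v] and parent[v] == c and 2 * size[v] > total), None)
            if heavy is None:
                break
            c = heavy
        depth[c] = d
        removed[c] = True
        work.extend((v, d + 1) for v in adj[c] if not removed[v])
    return depth
```

Since a component at depth d has at most N / 2^d vertices, the maximum depth is at most floor(log2 N), even on a chain.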

Typically used for tree DP and offline path queries. It is a powerful technique for problems of the form: “count/sum over all pairs (u, v) in a tree with property P on the u-v path.”

Interview Context

Almost exclusively ICPC. Some appearances in:

  • Quant algo research on tree models
  • Phylogenetic inference (computational biology)
  • A handful of compiler dominator-tree analyses

Industry interviews: near-zero.

When to Skip This Topic

Skip if any of these are true:

  • You are not training for ICPC or competitive contests
  • You haven’t done Lab 03 (HLD) — these are sibling techniques
  • You don’t have 2+ weeks for the implementation practice

Centroid decomposition has a high “first implementation” cost. Don’t attempt without serious tree-DP fluency.

Problem Statement

Count Paths in Tree with Length ≤ K.

Given a tree of N vertices, edge weights w_e, and integer K, count the number of unordered pairs (u, v) such that the sum of edge weights on the path from u to v is ≤ K.

Constraints

  • 1 ≤ N ≤ 5×10^4
  • 1 ≤ K ≤ 10^9
  • 1 ≤ w_e ≤ 10^4

Clarifying Questions

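  1. Are weights positive? (Yes, per the constraints. The sort and two-pointer count would also work with zero or negative weights, but positivity lets you discard any distance already greater than K.)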
  2. Count ordered or unordered pairs? (Unordered, exclude self-pairs.)
  3. Are edge weights integers? (Yes — convenient for sort/binary-search.)

Examples

Tree: 1-2 (w=2), 2-3 (w=1), 2-4 (w=3)
K = 4
Paths and lengths:
  (1,2): 2 ✓
  (1,3): 3 ✓
  (1,4): 5 ✗
  (2,3): 1 ✓
  (2,4): 3 ✓
  (3,4): 4 ✓
Answer: 5

Brute Force

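For each unordered pair (u, v), compute the path length as dist(root, u) + dist(root, v) − 2·dist(root, LCA(u, v)). O(N² log N) with a binary-lifting LCA.

For N = 5×10^4: about 1.25×10^9 unordered pairs, each needing an O(log N) LCA query, so on the order of 10^10 operations. TLE.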

Brute Force Complexity

  • Time: O(N² log N) for path length per pair
  • Space: O(N) plus LCA tables
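For use as a test oracle, a simpler brute force avoids LCA tables altogether by traversing from every vertex, which costs O(N²) time. The sketch below assumes 0-indexed vertices and an edge list of (u, v, w) triples; `brute_force_count` is a hypothetical name, not part of the lab's reference code.

```python
def brute_force_count(n, edges, K):
    """Count unordered pairs (u, v), u != v, whose path weight is <= K, in O(N^2)."""
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    count = 0
    for src in range(n):
        # Iterative DFS from src, carrying the distance accumulated so far.
        stack = [(src, -1, 0)]
        while stack:
            u, par, d = stack.pop()
            # Count each unordered pair once, from its smaller endpoint.
            if u > src and d <= K:
                count += 1
            for v, w in adj[u]:
                if v != par:
                    stack.append((v, u, d + w))
    return count
```

With the example tree from above re-indexed from 0, `brute_force_count(4, [(0, 1, 2), (1, 2, 1), (1, 3, 3)], 4)` gives 5, matching the hand count.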

Optimization Path

Centroid decomposition shines for “paths through centroid” enumeration:

  • A path between u and v either passes through the centroid c or lies entirely in one subtree (after c is removed)
  • Paths through c: count by aggregating distances from c to every other vertex
  • Paths in subtrees: handled recursively

Per centroid:

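  1. Traverse (DFS or BFS) from c, recording dist(c, v) for every v in c’s connected component, grouped by the child subtree of c that contains v
  2. Sort the combined distance list, and separately the list of each child subtree
  3. Count pairs (u, v) with dist(c, u) + dist(c, v) ≤ K over the combined list using two pointers
  4. Subtract the pairs where u and v lie in the same child subtree: their path does not pass through c, so dist(c, u) + dist(c, v) overstates their true distance; they are counted later, in the recursion on that subtree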
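The four steps can be tried in isolation on one centroid. In this sketch the distance lists are given directly instead of being gathered from a tree, and `paths_through_centroid` is a hypothetical wrapper introduced only for illustration.

```python
def count_pairs(dists, K):
    """Number of index pairs i < j with dists[i] + dists[j] <= K (two pointers)."""
    dists = sorted(dists)
    i, j, c = 0, len(dists) - 1, 0
    while i < j:
        if dists[i] + dists[j] <= K:
            c += j - i   # dists[i] pairs with every element in positions i+1..j
            i += 1
        else:
            j -= 1
    return c


def paths_through_centroid(subtree_dists, K):
    """subtree_dists[i] holds dist(c, v) for every v in the i-th child subtree of c."""
    all_dists = [0]      # the centroid itself, at distance 0
    same_subtree = 0
    for sub in subtree_dists:
        same_subtree += count_pairs(sub, K)
        all_dists.extend(sub)
    return count_pairs(all_dists, K) - same_subtree
```

Take a centroid c with one child subtree holding vertices a and b (distances 1 and 2) and another holding d (distance 1), with K = 3. The combined list yields 6 pairs, the pair (a, b) is subtracted, and the remaining 5 are exactly the paths through c: (c, a), (c, b), (c, d), (a, d), (b, d).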

Final Expected Approach

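def centroid_decompose(adj, root, K):
    # adj[u]: list of (neighbor, weight) pairs; vertices are 0..n-1.
    # The helper DFSs below are recursive: on a long chain their depth reaches N,
    # so raise Python's recursion limit or convert them to explicit stacks.
    n = len(adj)
    removed = [False] * n   # True once a vertex has been used as a centroid
    size = [0] * n          # subtree sizes, valid only for the current component
    total = 0

    def calc_size(u, parent):
        # Subtree sizes within the current component, rooted at the first u passed in.
        size[u] = 1
        for v, _ in adj[u]:
            if v != parent and not removed[v]:
                calc_size(v, u)
                size[u] += size[v]

    def find_centroid(u, parent, tree_size):
        # Descend into the child holding more than half the component, if one exists.
        for v, _ in adj[u]:
            if v != parent and not removed[v] and size[v] > tree_size // 2:
                return find_centroid(v, u, tree_size)
        return u

    def gather_dists(u, parent, d, out):
        # Append the distance from the centroid to every vertex of this subtree.
        out.append(d)
        for v, w in adj[u]:
            if v != parent and not removed[v]:
                gather_dists(v, u, d + w, out)

    def count_pairs(dists, K):
        # Number of index pairs i < j with dists[i] + dists[j] <= K (sorts in place).
        dists.sort()
        i, j = 0, len(dists) - 1
        c = 0
        while i < j:
            if dists[i] + dists[j] <= K:
                c += j - i   # dists[i] pairs with everything in positions i+1..j
                i += 1
            else:
                j -= 1
        return c

    def decompose(u):
        nonlocal total
        calc_size(u, -1)                      # sizes must be recomputed per component
        c = find_centroid(u, -1, size[u])
        all_dists = [0]                       # the centroid itself, at distance 0
        for v, w in adj[c]:
            if not removed[v]:
                sub = []
                gather_dists(v, c, w, sub)
                # subtract pairs within this subtree: their path does not pass through c
                total -= count_pairs(sub[:], K)
                all_dists.extend(sub)
        total += count_pairs(all_dists, K)
        removed[c] = True
        for v, _ in adj[c]:
            if not removed[v]:
                decompose(v)

    decompose(root)
    return total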

Data Structures

  • Adjacency list (vertex → list of (neighbor, weight))
  • removed[v]: marks centroids removed from active tree
  • size[v]: subtree size in current decomposition step
  • Distance lists per subtree

Correctness Argument

  • Centroid existence: every tree has a centroid (induction on tree structure).
  • Recursion depth O(log N): removing centroid leaves subtrees of size ≤ N/2.
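  • Pair counting via subtraction: a path (u, v) is counted exactly once, at the first centroid removed among the vertices of path(u, v), which is the LCA of u and v in the centroid tree. Until that removal u and v share a component; afterwards they are separated. The inclusion-exclusion (add the count over all vertices, subtract the count within each child subtree) ensures each path through c is counted once.
  • Two pointers for sum ≤ K: with the list sorted, if dists[i] + dists[j] ≤ K then every index in (i, j] also pairs with i, contributing j − i pairs; otherwise dists[j] pairs with nothing at or after i and can be dropped.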
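The subtraction argument can be stated as a short identity. The symbols D, D_i, and P are introduced here for illustration and are not used elsewhere in the lab.

```latex
% D    : multiset of dist(c, v) over every vertex v in c's component (contains 0 for c itself)
% D_i  : the part of D contributed by the i-th child subtree of c
% P(S) : number of unordered pairs of distinct entries {x, y} of S with x + y \le K
\[
  \#\{\text{paths through } c \text{ with length} \le K\}
  \;=\; P(D) \;-\; \sum_{i} P(D_i)
\]
```

A pair counted by P(D) either has its endpoints in different child subtrees (or has c as an endpoint), in which case dist(u, v) = dist(c, u) + dist(c, v) and the pair is a genuine path through c, or has both endpoints in the same subtree i, in which case it is counted by P(D_i). Conversely, every pair counted by P(D_i) is also counted by P(D), so the subtraction removes exactly the same-subtree pairs.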

Complexity

  • Time: O(N log² N) — O(log N) levels, O(N log N) per level (sort dominates)
  • Space: O(N) for tree + O(N) for decomposition state
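The time bound follows from a standard recurrence, sketched here with N_i denoting the size of the i-th component left after removing the centroid (a symbol introduced for this derivation only).

```latex
\[
  T(N) \;=\; \sum_i T(N_i) \;+\; O(N \log N),
  \qquad N_i \le \lfloor N/2 \rfloor, \quad \sum_i N_i = N - 1
\]
\[
  \text{depth} \le \lfloor \log_2 N \rfloor, \quad
  \text{work per level} = O(N \log N)
  \;\;\Longrightarrow\;\; T(N) = O(N \log^2 N)
\]
```

For N = 5×10^4, log2 N is about 15.6, so N log² N is roughly 1.2×10^7 elementary steps before constant factors.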

Implementation Requirements

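  • Iterative or carefully bounded recursive DFS (Python: a chain of 5×10^4 vertices far exceeds the default recursion limit of 1000; raise it with sys.setrecursionlimit and, if needed, run in a thread with a larger stack, or go iterative)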
  • Recompute size[] for each subtree (in the recursive call); critical bug source
  • Two-pointer pair counting requires sorted distances
  • The inclusion-exclusion trick is the conceptual core; verify on small cases
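One way to meet the first two requirements in Python is to drop recursion entirely. The following is a sketch under assumptions, not the lab's reference solution: `count_paths_at_most_k` and its (u, v, w) edge-list input are names chosen here, vertices are 0..n-1, and the logic mirrors the recursive version with explicit stacks and a work list of components.

```python
def count_paths_at_most_k(n, edges, K):
    """Count unordered pairs (u, v), u != v, with path weight <= K, without recursion."""
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))
    removed = [False] * n
    parent = [-1] * n
    size = [0] * n

    def find_centroid(start):
        # BFS order of the current component; the list grows while we iterate over it.
        parent[start] = -1
        order = [start]
        for u in order:
            for v, _ in adj[u]:
                if v != parent[u] and not removed[v]:
                    parent[v] = u
                    order.append(v)
        # Recompute sizes for this component only (children before parents).
        for u in order:
            size[u] = 1
        for u in reversed(order):
            if parent[u] != -1:
                size[parent[u]] += size[u]
        total_size, c = len(order), start
        while True:
            heavy = next((v for v, _ in adj[c]
                          if not removed[v] and parent[v] == c
                          and 2 * size[v] > total_size), None)
            if heavy is None:
                return c
            c = heavy

    def gather_dists(start, par, d0):
        # Iterative DFS collecting distances from the centroid into one child subtree.
        out, stack = [], [(start, par, d0)]
        while stack:
            u, p, d = stack.pop()
            out.append(d)
            for v, w in adj[u]:
                if v != p and not removed[v]:
                    stack.append((v, u, d + w))
        return out

    def count_pairs(dists):
        dists.sort()
        i, j, c = 0, len(dists) - 1, 0
        while i < j:
            if dists[i] + dists[j] <= K:
                c += j - i
                i += 1
            else:
                j -= 1
        return c

    total = 0
    work = [0] if n > 0 else []   # one representative vertex per pending component
    while work:
        c = find_centroid(work.pop())
        all_dists = [0]
        for v, w in adj[c]:
            if not removed[v]:
                sub = gather_dists(v, c, w)
                total -= count_pairs(sub)      # same-subtree pairs do not pass through c
                all_dists.extend(sub)
        total += count_pairs(all_dists)
        removed[c] = True
        work.extend(v for v, _ in adj[c] if not removed[v])
    return total
```

The order in which components are popped from `work` does not affect the result, since each component is processed independently once its centroid has been removed from its parent component.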

Tests

  • Linear chain (path graph): N(N-1)/2 paths; verify against brute force
  • Star tree: each pair sum is at most 2*max_weight
  • Balanced binary tree
  • N = 1 (no pairs)
  • K = 0 with positive weights (outside the stated constraint K ≥ 1, but a cheap edge case: no pair of distinct vertices qualifies, so answer = 0)
  • Very large K (all pairs counted)
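Several of these shapes have closed-form answers when all weights are equal, which makes them convenient fixtures. The generators and formulas below are a sketch; every function name is hypothetical, and edges are (u, v, w) triples over vertices 0..n-1.

```python
import random


def chain_edges(n, w=1):
    """Path graph 0 - 1 - ... - (n-1) with uniform weight w."""
    return [(i, i + 1, w) for i in range(n - 1)]


def star_edges(n, w=1):
    """Star with center 0 and n - 1 leaves, uniform weight w."""
    return [(0, i, w) for i in range(1, n)]


def random_tree_edges(n, max_w=10, seed=0):
    """Random tree: vertex i attaches to a uniformly chosen earlier vertex."""
    rng = random.Random(seed)
    return [(rng.randrange(i), i, rng.randint(1, max_w)) for i in range(1, n)]


def expected_chain(n, w, K):
    # In a uniform chain exactly n - d pairs are d edges apart.
    max_hops = min(n - 1, K // w)
    return sum(n - d for d in range(1, max_hops + 1))


def expected_star(n, w, K):
    leaves = n - 1
    center_pairs = leaves if w <= K else 0                        # center-leaf: length w
    leaf_pairs = leaves * (leaves - 1) // 2 if 2 * w <= K else 0  # leaf-leaf: length 2w
    return center_pairs + leaf_pairs
```

Random trees have no closed form, so for those the solver's output would be compared against a brute-force count on small N.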

Follow-up Questions

  • Count paths with length exactly K. → Use hashmap of distances per subtree; sum complement counts.
  • Sum of path lengths (rather than count). → Aggregate sums in addition to counts during two-pointer scan.
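  • XOR of edge weights instead of sum, equals K. → Centroid decomposition is not needed here: with x[v] the XOR of weights from the root to v, the path XOR is x[u] xor x[v], so count complements with a hashmap. An XOR trie comes in for the “XOR ≤ K” or “maximum XOR” variants.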
  • Online (tree mutating). → Much harder; dynamic-tree structures such as link-cut trees or top trees are the usual starting point.
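  • K-th shortest path (k-th smallest path length over all pairs). → Binary search on a length L, using this lab’s “count paths with length ≤ L” routine as the predicate; this adds a factor of log(max path length).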
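For the XOR follow-up with an equality target, one possible approach skips centroid decomposition: the XOR along the path from u to v equals x[u] xor x[v], where x[v] is the XOR of weights from a fixed root to v, because the shared root-to-LCA prefix cancels. The sketch below assumes n ≥ 1, vertices 0..n-1, and (u, v, w) edge triples; `count_xor_paths` is a hypothetical name.

```python
from collections import Counter


def count_xor_paths(n, edges, K):
    """Count unordered pairs (u, v), u != v, whose path XOR equals K."""
    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    # px[v] = XOR of edge weights on the path from vertex 0 to v.
    px = [0] * n
    seen = [False] * n
    seen[0] = True
    stack = [0]
    while stack:
        u = stack.pop()
        for v, w in adj[u]:
            if not seen[v]:
                seen[v] = True
                px[v] = px[u] ^ w
                stack.append(v)

    # px[u] ^ px[v] == K  <=>  px[u] == px[v] ^ K; look up before inserting
    # so that a vertex is never paired with itself.
    count = 0
    freq = Counter()
    for x in px:
        count += freq[x ^ K]
        freq[x] += 1
    return count
```

On the example tree (weights 2, 1, 3) the root XORs are 0, 2, 3, 1, and two paths have XOR equal to 1: the pair (1, 4) and the pair (2, 3) in the original 1-indexed labels.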

Product Extension

  • Phylogenetics: counting pairs of species within evolutionary distance K
  • Network distance queries on hierarchical trees
  • Distance-based recommendation systems on tree-like ontologies

Language/Runtime Follow-ups

  • Python: sort + two-pointer per level; constant factor is the killer. C++ recommended for N ≥ 10^4.
  • C++: standard ICPC implementation; ~100 lines.
  • Recursive DFS: centroid decomposition depth O(log N), but inner DFS depth O(N) — limit must accommodate.

Common Bugs

  1. Forgot to recompute size[] for each subtree. Sizes from before removal are stale.
  2. Centroid finder doesn’t follow the right child. It must descend into the child whose subtree holds more than half of the current component (at most one such child exists).
  3. removed[v] check forgotten in DFS: revisits removed centroids.
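  4. Off-by-one in pair counting (counting self-pair). Keep i < j strict in the two-pointer loop so no entry pairs with itself; the centroid’s own distance-0 entry is what accounts for paths ending at c.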
  5. Inclusion-exclusion wrong sign. Add all, subtract per-subtree.
  6. Stack overflow on deep recursion. Convert inner DFS to iterative.

Debugging Strategy

  • For small N, compare against brute force at each level
  • Log the centroid chosen at each call
  • Verify subtree sizes recomputed correctly (print before find_centroid)
  • For two-pointer: print sorted distances and the (i, j) cursor trajectory

Mastery Criteria

  • Implement centroid decomposition in ≤ 60 min from memory
  • Explain the inclusion-exclusion trick for path counting
  • Identify problems amenable to centroid decomposition (offline path queries on static tree)
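  • Distinguish from HLD: HLD answers individual path queries online, with point updates to vertex or edge values on a fixed tree; centroid decomposition aggregates over all paths at once (offline path counting)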
  • State complexity precisely: O(N log² N) typical