"""
LC 235 — Lowest Common Ancestor of a Binary Search Tree.

Four implementations:
    1. lca_recursive          — recursive split-point descent (O(H) time for tree height H, O(H) recursion stack).
    2. lca_iterative          — iterative split-point descent (O(H) time, O(1) space — ship this).
    3. lca_general_recursive  — LC 236 general-tree LCA (post-order, 4-case) for follow-up + cross-check.
    4. lca_path_intersect     — O(N) brute force: find both root-to-node paths, return their last common node.

INVARIANT (BST LCA shortcut): the LCA is the first node on the search path
from the root where p and q diverge — i.e., the first node N where
p.val and q.val are NOT both < N.val and NOT both > N.val.
"""

from __future__ import annotations

import random
import sys
from typing import Optional

# Raise CPython's recursion ceiling: the recursive LCA variants below descend
# one call per tree level, and a skewed BST has height equal to its node count.
sys.setrecursionlimit(20_000)


class TreeNode:
    """Binary tree node: an integer key plus optional left/right child links."""

    __slots__ = ("val", "left", "right")

    def __init__(
        self,
        val: int = 0,
        left: Optional["TreeNode"] = None,
        right: Optional["TreeNode"] = None,
    ):
        # __slots__ means no per-instance __dict__; assign all three slots at once.
        self.val, self.left, self.right = val, left, right


# ----------------------------------------------------------------------
# Solution A — BST recursive descent.
# ----------------------------------------------------------------------

def lca_recursive(root: Optional[TreeNode], p: TreeNode, q: TreeNode) -> TreeNode:
    """Return the LCA of ``p`` and ``q`` in the BST rooted at ``root`` (recursive form).

    While both keys lie strictly on the same side of the current node, recurse
    into that child. The first node where they do not is the split point and
    therefore the LCA. A node counts as its own ancestor, so if ``p`` or ``q``
    is the current node, it is returned.

    Precondition: ``p`` and ``q`` are nodes of this BST. Only ``.val`` is
    compared; node identity is not checked.

    Raises:
        ValueError: if the descent walks off the tree (``root`` is None or the
            required child is missing). This is the same error ``lca_iterative``
            raises.
    """
    if root is None:
        # Explicit error instead of an AttributeError on ``None.val``.
        raise ValueError("LCA not found — precondition violated (p or q missing).")
    if p.val < root.val and q.val < root.val:
        return lca_recursive(root.left, p, q)
    if p.val > root.val and q.val > root.val:
        return lca_recursive(root.right, p, q)
    return root


# ----------------------------------------------------------------------
# Solution B — BST iterative descent (O(1) space — production form).
# ----------------------------------------------------------------------

def lca_iterative(root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
    """Walk down from ``root`` to the split point of ``p`` and ``q`` without recursion.

    Uses a single cursor, so extra space is O(1) and time is O(H) for tree
    height H. Returns the first node whose value is neither strictly greater
    than both keys nor strictly less than both.

    Raises:
        ValueError: if the walk falls off the tree (precondition violated).
    """
    # Order the two keys once; "both below" becomes hi < x, "both above" becomes lo > x.
    lo, hi = (p.val, q.val) if p.val <= q.val else (q.val, p.val)
    cursor = root
    while cursor is not None:
        if hi < cursor.val:
            cursor = cursor.left  # type: ignore[assignment]
        elif lo > cursor.val:
            cursor = cursor.right  # type: ignore[assignment]
        else:
            return cursor
    raise ValueError("LCA not found — precondition violated (p or q missing).")


# ----------------------------------------------------------------------
# Solution C — General binary tree LCA (LC 236), works regardless of BST property.
# ----------------------------------------------------------------------

def lca_general_recursive(root: Optional[TreeNode], p: TreeNode, q: TreeNode) -> Optional[TreeNode]:
    """LC 236: LCA in an arbitrary binary tree, matching ``p``/``q`` by identity.

    Post-order search that ignores BST ordering:

    - A call reports ``root`` itself if it is ``p`` or ``q``.
    - Otherwise it reports whatever its children found.
    - A node whose two subtrees each report something is the split point.

    Returns None when neither target occurs in the subtree.
    """
    if root is None:
        return None
    if root is p or root is q:
        return root
    from_left = lca_general_recursive(root.left, p, q)
    from_right = lca_general_recursive(root.right, p, q)
    if from_left is None:
        return from_right
    if from_right is None:
        return from_left
    # Targets found on both sides: this node is where the paths diverge.
    return root


# ----------------------------------------------------------------------
# Solution D — Path intersection brute force (oracle).
# ----------------------------------------------------------------------

def lca_path_intersect(root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
    """Brute-force oracle: the LCA is the last node shared by the root→p and root→q paths.

    Works on any binary tree: nodes are matched by identity and BST order is
    unused. Each path search is a DFS that visits every node at most once,
    so the function runs in O(N) time.

    Raises:
        ValueError: if ``p`` or ``q`` is not reachable from ``root``.
    """
    def path_to(target: TreeNode) -> list[TreeNode]:
        # DFS that records the root→target path; returns [] if target is absent.
        path: list[TreeNode] = []

        def go(node: Optional[TreeNode]) -> bool:
            if node is None:
                return False
            path.append(node)
            if node is target:
                return True
            if go(node.left) or go(node.right):
                return True
            path.pop()  # dead end: unwind this node before trying elsewhere
            return False

        go(root)
        return path

    path_p = path_to(p)
    path_q = path_to(q)
    if not path_p or not path_q:
        # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
        raise ValueError("LCA not found — precondition violated (p or q missing).")
    # Both non-empty paths start at root, so the first pair always matches.
    last_common = path_p[0]
    for a, b in zip(path_p, path_q):
        if a is not b:
            break
        last_common = a
    return last_common


# ----------------------------------------------------------------------
# Helpers.
# ----------------------------------------------------------------------

def insert_bst(root: Optional[TreeNode], v: int) -> TreeNode:
    """Insert key ``v`` into the BST rooted at ``root``; return the (possibly new) root.

    Duplicate keys are ignored, so the tree keeps strictly unique keys.
    """
    if root is None:
        return TreeNode(v)
    cur = root
    while True:
        if v < cur.val:
            if cur.left is None:
                cur.left = TreeNode(v)
                break
            cur = cur.left
        elif v > cur.val:
            if cur.right is None:
                cur.right = TreeNode(v)
                break
            cur = cur.right
        else:
            break  # duplicate key: leave the tree unchanged
    return root


def build_random_bst(rng: random.Random, n: int) -> tuple[Optional[TreeNode], list[TreeNode]]:
    """Build a BST from ``n`` distinct random keys and list every node in it.

    Keys are drawn without replacement from ``range(-50_000, 50_000)`` using
    ``rng`` and inserted in draw order. The node list is in pre-order (node,
    then left subtree, then right subtree), so callers can sample ``p`` and
    ``q`` from it.

    Returns:
        ``(None, [])`` when ``n == 0``; otherwise ``(root, nodes)``.
    """
    if not n:
        return None, []
    root: Optional[TreeNode] = None
    for key in rng.sample(range(-50_000, 50_000), n):
        root = insert_bst(root, key)
    # Iterative pre-order walk: push right before left so left is popped first.
    nodes: list[TreeNode] = []
    pending: list[TreeNode] = [root]  # type: ignore[list-item]
    while pending:
        current = pending.pop()
        nodes.append(current)
        pending.extend(child for child in (current.right, current.left) if child)
    return root, nodes


# ----------------------------------------------------------------------
# Stress test.
# ----------------------------------------------------------------------

def stress_test(n_iter: int = 1000, seed: int = 42, max_size: int = 60) -> None:
    """Cross-check all four LCA implementations on random BSTs.

    Each iteration builds a BST of 2..``max_size`` distinct random keys, picks
    two distinct nodes, and asserts that every variant returns the same node
    (by identity). The defaults reproduce the original fixed run: seed 42,
    1000 trees, sizes 2..60.

    Args:
        n_iter: number of random trees to check.
        seed: seed for a private ``random.Random``, so runs are reproducible.
        max_size: largest tree size; must be >= 2 so two distinct nodes exist.

    Raises:
        ValueError: if ``max_size < 2``.
        AssertionError: if any implementation disagrees with the others.
    """
    if max_size < 2:
        raise ValueError(f"max_size must be >= 2 to pick two distinct nodes, got {max_size}")
    rng = random.Random(seed)
    for _ in range(n_iter):
        size = rng.randint(2, max_size)
        root, nodes = build_random_bst(rng, size)
        assert root is not None
        p, q = rng.sample(nodes, 2)
        a = lca_recursive(root, p, q)
        b = lca_iterative(root, p, q)
        c = lca_general_recursive(root, p, q)
        d = lca_path_intersect(root, p, q)
        assert a is b is c is d, (
            f"Disagreement: rec={a.val}, iter={b.val}, "
            f"general={c.val if c else None}, brute={d.val} for p={p.val}, q={q.val}"
        )
    print(f"stress_test: {n_iter} random BSTs — all 4 variants agree on LCA.")
    print(f"stress_test: {n_iter} random BSTs — all 4 variants agree on LCA.")


# ----------------------------------------------------------------------
# Edge cases.
# ----------------------------------------------------------------------

def edge_cases() -> None:
    """Run hand-built cases through all four LCA implementations.

    Every case is checked with the arguments in both orders. LCA is symmetric
    in ``p``/``q``, and the BST variants compare ``p.val`` and ``q.val``
    separately, so order matters for coverage.

    Raises:
        AssertionError: naming the implementation and the case that failed.
    """
    fns = (lca_recursive, lca_iterative, lca_general_recursive, lca_path_intersect)

    def check(case: str, root: TreeNode, p: TreeNode, q: TreeNode, expected: TreeNode) -> None:
        # Every hand case below passes p.val < q.val, so also try the swapped order.
        for fn in fns:
            assert fn(root, p, q) is expected, f"{fn.__name__} failed {case} case"
            assert fn(root, q, p) is expected, f"{fn.__name__} failed {case} case (swapped)"

    # The canonical LC example: [6,2,8,0,4,7,9,null,null,3,5].
    root = TreeNode(6)
    root.left = TreeNode(2, TreeNode(0), TreeNode(4, TreeNode(3), TreeNode(5)))
    root.right = TreeNode(8, TreeNode(7), TreeNode(9))
    # p=2, q=8 → LCA is the root 6.
    check("canonical", root, root.left, root.right, root)

    # p ancestor of q: p=2, q=4 → LCA is 2.
    check("p-ancestor", root, root.left, root.left.right, root.left)  # type: ignore[arg-type]

    # Deep split: p=3, q=5 → LCA is 4.
    check("deep-split", root, root.left.right.left, root.left.right.right,  # type: ignore[union-attr, arg-type]
          root.left.right)  # type: ignore[arg-type]

    # Two-node tree: [1, null, 2].
    r2 = TreeNode(1, None, TreeNode(2))
    check("two-node", r2, r2, r2.right, r2)  # type: ignore[arg-type]

    # Right-skewed BST 1->2->3->4->5; LCA(2, 5) should be 2.
    rs: Optional[TreeNode] = None
    for v in [1, 2, 3, 4, 5]:
        rs = insert_bst(rs, v)
    assert rs is not None
    n2 = rs.right
    n5 = rs.right.right.right.right  # type: ignore[union-attr]
    check("right-skewed", rs, n2, n5, n2)  # type: ignore[arg-type]

    # Left-skewed BST 5->4->3->2->1; LCA(1, 3) should be 3.
    ls: Optional[TreeNode] = None
    for v in [5, 4, 3, 2, 1]:
        ls = insert_bst(ls, v)
    assert ls is not None
    n3 = ls.left.left  # type: ignore[union-attr]
    n1 = ls.left.left.left.left  # type: ignore[union-attr]
    check("left-skewed", ls, n1, n3, n3)  # type: ignore[arg-type]

    print("edge_cases: PASS")


if __name__ == "__main__":
    # Hand-picked cases first (fast, readable failures), then the randomized cross-check.
    for suite in (edge_cases, stress_test):
        suite()
    print("ALL TESTS PASSED")
