"""
LC 230 — Kth Smallest Element in a BST.

Three base implementations:
    1. kth_smallest_iterative           — iterative inorder + counter early termination (ship this).
    2. kth_smallest_recursive           — recursive inorder with closure-held counter + early-exit check.
    3. kth_smallest_brute_full_inorder  — full inorder into list; oracle.

Plus the augmented-BST follow-up:
    class OrderStatBST — every node carries `size`; supports insert / delete / kth_smallest / rank
    all in O(H).  Cross-checked against a sorted-list oracle on randomized insert/delete sequences.

INVARIANT (base): inorder traversal of a BST yields values in strictly ascending order;
the k-th value visited (1-indexed) is the k-th smallest.

INVARIANT (augmented): for every node, node.size == 1 + size(left) + size(right),
where size(None) := 0.  Must be maintained on every insert and delete.
"""

from __future__ import annotations

import random
import sys
from bisect import insort
from typing import Optional

# Raise the interpreter's recursion ceiling: the recursive routines in this file
# descend one frame per tree level, and nothing here rebalances, so the depth
# can reach the node count on a skewed tree.
sys.setrecursionlimit(20_000)


class TreeNode:
    """Plain binary-tree node: a value plus optional left and right children."""

    # Fixed attribute set; instances carry no per-object __dict__.
    __slots__ = ("val", "left", "right")

    def __init__(
        self,
        val: int = 0,
        left: Optional["TreeNode"] = None,
        right: Optional["TreeNode"] = None,
    ):
        self.val, self.left, self.right = val, left, right


# ----------------------------------------------------------------------
# Solution A — iterative inorder with k-counter (early terminating).
# ----------------------------------------------------------------------

def kth_smallest_iterative(root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-indexed) of the BST rooted at ``root``.

    Performs an explicit-stack inorder walk and returns as soon as the k-th
    node has been visited, so nodes after it are never touched.

    Args:
        root: Root of the BST, or ``None`` for an empty tree.
        k: 1-indexed rank of the value to return.

    Returns:
        The value of the k-th node in inorder sequence.

    Raises:
        ValueError: If ``k < 1`` (rejected before any traversal) or the tree
            holds fewer than ``k`` nodes.
    """
    if k < 1:
        # A non-positive k can never count down to exactly zero; fail fast
        # rather than walking the whole tree and then reporting "too few nodes".
        raise ValueError(f"k out of range — k must be >= 1, got {k}.")
    stack: list[TreeNode] = []
    cur = root
    while cur is not None or stack:
        # Descend the left spine; the deepest node pushed is the next inorder node.
        while cur is not None:
            stack.append(cur)
            cur = cur.left
        node = stack.pop()
        k -= 1  # one more node visited in ascending order
        if k == 0:
            return node.val
        cur = node.right
    raise ValueError("k out of range — fewer than k nodes in tree.")


# ----------------------------------------------------------------------
# Solution B — recursive inorder with closure counter + early flag.
# ----------------------------------------------------------------------

def kth_smallest_recursive(root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-indexed) of the BST rooted at ``root``.

    Recursive inorder traversal. The enclosing scope holds a countdown and the
    answer; once the answer is set, every pending call returns immediately.

    Args:
        root: Root of the BST, or ``None`` for an empty tree.
        k: 1-indexed rank of the value to return.

    Returns:
        The value of the k-th node in inorder sequence.

    Raises:
        ValueError: If ``k < 1`` (rejected before any traversal) or the tree
            holds fewer than ``k`` nodes.
    """
    if k < 1:
        # The countdown below only stops on exactly zero, which a
        # non-positive k can never reach; reject it before recursing.
        raise ValueError(f"k out of range — k must be >= 1, got {k}.")

    remaining = k
    result: Optional[int] = None

    def go(node: Optional[TreeNode]) -> None:
        nonlocal remaining, result
        if node is None or result is not None:
            return
        go(node.left)
        if result is not None:
            # Answer was found somewhere in the left subtree; stop unwinding work.
            return
        remaining -= 1
        if remaining == 0:
            result = node.val
            return
        go(node.right)

    go(root)
    if result is None:
        raise ValueError("k out of range — fewer than k nodes in tree.")
    return result


# ----------------------------------------------------------------------
# Solution C — brute force: full inorder into list (oracle).
# ----------------------------------------------------------------------

def kth_smallest_brute_full_inorder(root: Optional[TreeNode], k: int) -> int:
    """Return the k-th smallest value (1-indexed) by materializing the full inorder list.

    Intended as a simple reference implementation: it always visits every
    node, collects the values in ascending order, then indexes the list.

    Args:
        root: Root of the BST, or ``None`` for an empty tree.
        k: 1-indexed rank of the value to return.

    Returns:
        The value at position ``k`` of the inorder sequence.

    Raises:
        ValueError: If ``k`` is not in ``[1, number of nodes]``.
    """
    out: list[int] = []

    def go(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        go(node.left)
        out.append(node.val)
        go(node.right)

    go(root)
    # Check the range explicitly: a bare out[k - 1] would wrap around for
    # k <= 0 (k == 0 yields the *largest* value) instead of failing.
    if not 1 <= k <= len(out):
        raise ValueError(f"k={k} out of range [1, {len(out)}]")
    return out[k - 1]


# ----------------------------------------------------------------------
# Augmented BST — the follow-up.
# ----------------------------------------------------------------------

class _OSNode:
    """BST node that additionally caches the node count of its own subtree."""

    __slots__ = ("val", "left", "right", "size")

    def __init__(self, val: int):
        # A new node has no children, so its subtree is just itself:
        # size == 1 + size(left) + size(right) == 1.
        self.size = 1
        self.left: Optional[_OSNode] = None
        self.right: Optional[_OSNode] = None
        self.val = val


def _size(n: Optional[_OSNode]) -> int:
    """Return the cached subtree size of ``n``; an absent subtree (``None``) counts as 0."""
    if n is None:
        return 0
    return n.size


def _resize(n: _OSNode) -> None:
    """Recompute ``n.size`` from its children's cached sizes (children must already be correct)."""
    left_count = _size(n.left)
    right_count = _size(n.right)
    n.size = 1 + left_count + right_count


class OrderStatBST:
    """Unbalanced BST of unique values whose nodes cache their subtree sizes.

    The cached sizes let ``kth_smallest`` and ``rank`` follow a single
    root-to-node path.  Every operation is O(H); since nothing rebalances,
    H can be as large as N.
    """

    def __init__(self) -> None:
        self.root: Optional[_OSNode] = None

    # --- insertion -------------------------------------------------
    def insert(self, v: int) -> None:
        """Add ``v`` to the tree; inserting a value already present changes nothing."""
        self.root = self._insert(self.root, v)

    def _insert(self, node: Optional[_OSNode], v: int) -> _OSNode:
        """Insert ``v`` below ``node`` and return the (possibly new) subtree root."""
        if node is None:
            return _OSNode(v)
        if v < node.val:
            node.left = self._insert(node.left, v)
        elif v > node.val:
            node.right = self._insert(node.right, v)
        # Equal value: nothing was attached, and the recomputation below
        # simply reproduces the same size.
        _resize(node)
        return node

    # --- deletion --------------------------------------------------
    def delete(self, v: int) -> None:
        """Remove ``v`` from the tree; deleting an absent value changes nothing."""
        self.root = self._delete(self.root, v)

    def _delete(self, node: Optional[_OSNode], v: int) -> Optional[_OSNode]:
        """Delete ``v`` below ``node`` and return the resulting subtree root."""
        if node is None:
            return None
        if v < node.val:
            node.left = self._delete(node.left, v)
        elif v > node.val:
            node.right = self._delete(node.right, v)
        elif node.left is None:
            # Target found with at most a right child: splice that child in.
            return node.right
        elif node.right is None:
            # Target found with only a left child: splice that child in.
            return node.left
        else:
            # Target has two children: copy the inorder successor's value
            # (leftmost node of the right subtree) into this node, then
            # delete that successor from the right subtree.
            heir = node.right
            while heir.left is not None:
                heir = heir.left
            node.val = heir.val
            node.right = self._delete(node.right, heir.val)
        _resize(node)
        return node

    # --- queries ---------------------------------------------------
    def __len__(self) -> int:
        """Return the number of stored values (the root's cached size)."""
        return _size(self.root)

    def kth_smallest(self, k: int) -> int:
        """Return the k-th smallest stored value (1-indexed).

        Raises:
            ValueError: If ``k`` is outside ``[1, len(self)]``.
        """
        total = _size(self.root)
        if not (1 <= k <= total):
            raise ValueError(f"k={k} out of range [1, {total}]")
        remaining = k
        cur = self.root
        while cur is not None:
            smaller = _size(cur.left)  # how many values sit below cur on its left
            if remaining == smaller + 1:
                return cur.val
            if remaining <= smaller:
                cur = cur.left
            else:
                # Skip the left subtree and cur itself, then continue on the right.
                remaining -= smaller + 1
                cur = cur.right
        raise RuntimeError("unreachable: k validated above")

    def rank(self, v: int) -> int:
        """Return the 1-indexed rank of ``v`` among the stored values.

        Raises:
            KeyError: If ``v`` is not in the tree.
        """
        preceding = 0  # values already known to be smaller than v
        cur = self.root
        while cur is not None:
            if v < cur.val:
                cur = cur.left
                continue
            # cur.val <= v here: cur's left subtree and cur itself are accounted for.
            through_here = preceding + _size(cur.left) + 1
            if v > cur.val:
                preceding = through_here
                cur = cur.right
            else:
                return through_here
        raise KeyError(v)

    def __contains__(self, v: int) -> bool:
        """Return True when ``v`` is stored in the tree."""
        cur = self.root
        # Keep descending while there is a node and it is not the match.
        while cur is not None and (v < cur.val or v > cur.val):
            cur = cur.left if v < cur.val else cur.right
        return cur is not None

    # --- invariant audit (used in stress test) ---------------------
    def assert_sizes(self) -> None:
        """Walk the whole tree and assert every cached size matches its real subtree size."""

        def audit(sub: Optional[_OSNode]) -> int:
            if sub is None:
                return 0
            expected = 1 + audit(sub.left) + audit(sub.right)
            assert sub.size == expected, f"size invariant broken at val={sub.val}"
            return sub.size

        audit(self.root)


# ----------------------------------------------------------------------
# Helpers.
# ----------------------------------------------------------------------

def insert_bst(root: Optional[TreeNode], v: int) -> TreeNode:
    """Insert ``v`` into the plain BST rooted at ``root`` and return the root.

    A value already present is left alone, so the tree keeps unique values.
    An empty tree (``None``) yields a fresh single-node tree.
    """
    if root is None:
        return TreeNode(v)
    cursor = root
    while True:
        if v < cursor.val:
            if cursor.left is None:
                cursor.left = TreeNode(v)
                break
            cursor = cursor.left
        elif v > cursor.val:
            if cursor.right is None:
                cursor.right = TreeNode(v)
                break
            cursor = cursor.right
        else:
            # Duplicate: nothing to attach.
            break
    return root


def build_random_bst(rng: random.Random, n: int) -> tuple[Optional[TreeNode], list[int]]:
    """Build a BST from ``n`` distinct random values drawn from [-50000, 50000).

    Values are inserted in the order ``rng.sample`` returns them.

    Returns:
        ``(root, ascending_values)``; ``(None, [])`` when ``n`` is 0, in which
        case ``rng`` is not consulted at all.
    """
    if n == 0:
        return None, []
    drawn = rng.sample(range(-50_000, 50_000), n)
    tree: Optional[TreeNode] = None
    for value in drawn:
        tree = insert_bst(tree, value)
    ascending = sorted(drawn)
    return tree, ascending


# ----------------------------------------------------------------------
# Stress test.
# ----------------------------------------------------------------------

def stress_test() -> None:
    """Randomized cross-check of every implementation, seeded for repeatability.

    Part A builds 1000 random BSTs (1-40 nodes) and checks that the three
    base solvers agree with the sorted value list for every valid k.

    Part B drives ``OrderStatBST`` through 500 random insert/delete/query
    sequences, comparing ``kth_smallest``, ``rank`` and ``len`` against a
    sorted-list oracle and auditing the size invariant after every operation.

    Raises:
        AssertionError: On the first disagreement or broken invariant.
    """
    rng = random.Random(42)

    # --- Part A: base problem — 1000 BSTs, all k, all 3 variants agree.
    a_iter = 1000
    for _ in range(a_iter):
        size = rng.randint(1, 40)
        root, sorted_vals = build_random_bst(rng, size)
        for k, d in enumerate(sorted_vals, start=1):
            a = kth_smallest_iterative(root, k)
            b = kth_smallest_recursive(root, k)
            c = kth_smallest_brute_full_inorder(root, k)
            assert a == b == c == d, f"Disagreement: {a, b, c, d} for k={k}"

    # --- Part B: augmented BST — random insert/delete/query sequences vs sorted-list oracle.
    b_iter = 500
    for _ in range(b_iter):
        tree = OrderStatBST()
        oracle: list[int] = []  # kept sorted
        present: set[int] = set()  # same elements as oracle, for O(1) membership
        n_ops = rng.randint(20, 200)
        for _step in range(n_ops):
            op = rng.random()
            if op < 0.5 or not oracle:
                # Insert (forced while the tree is empty so delete/query have data).
                v = rng.randint(-1000, 1000)
                if v not in present:
                    tree.insert(v)
                    insort(oracle, v)
                    present.add(v)
            elif op < 0.75:
                # Pick the victim straight from the oracle list: no throwaway
                # list per delete, and no dependence on set iteration order.
                v = rng.choice(oracle)
                tree.delete(v)
                oracle.remove(v)
                present.remove(v)
            else:
                # oracle is non-empty here, otherwise the insert branch ran.
                k = rng.randint(1, len(oracle))
                got = tree.kth_smallest(k)
                v = oracle[k - 1]
                assert got == v, (
                    f"kth_smallest disagreed: tree={got} oracle={v} k={k}"
                )
                # also verify rank inversion.
                assert tree.rank(v) == k, f"rank disagreed for v={v}"
            tree.assert_sizes()
            assert len(tree) == len(oracle), (
                f"len disagreed: tree={len(tree)} oracle={len(oracle)}"
            )

    print(f"stress_test: {a_iter} base BSTs (all k) + {b_iter} augmented insert/delete/query sequences — all agree.")


# ----------------------------------------------------------------------
# Edge cases.
# ----------------------------------------------------------------------

def edge_cases() -> None:
    """Hand-written boundary checks for the base solvers and ``OrderStatBST``.

    Covers a single node, the two LeetCode examples, left/right-skewed trees
    (on all three solvers), empty-tree and out-of-range ``k`` error paths for
    the two early-terminating solvers, and insert/delete/duplicate/absent-key
    behavior of the augmented tree, including its error paths.

    Raises:
        AssertionError: On the first failed expectation.
    """
    all_fns = (kth_smallest_iterative, kth_smallest_recursive, kth_smallest_brute_full_inorder)
    early_fns = (kth_smallest_iterative, kth_smallest_recursive)

    def expect_raises(exc_type, fn, *args) -> None:
        # Pass only if fn(*args) raises exc_type; any other exception propagates.
        try:
            fn(*args)
        except exc_type:
            return
        raise AssertionError(f"{fn.__name__}{args} did not raise {exc_type.__name__}")

    # Single node, k=1.
    one = TreeNode(7)
    for fn in all_fns:
        assert fn(one, 1) == 7

    # LC example: [3,1,4,null,2], k=1 → 1.
    t = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
    for fn in all_fns:
        for k in range(1, 5):
            assert fn(t, k) == k

    # LC example: [5,3,6,2,4,null,null,1], k=3 → 3.
    t2 = TreeNode(5,
                  TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)),
                  TreeNode(6))
    for fn in all_fns:
        assert fn(t2, 3) == 3
        assert fn(t2, 6) == 6

    # Left-skewed: [5,4,null,3,null,2,null,1] — values 1..5 inorder.
    ls: Optional[TreeNode] = None
    for v in [5, 4, 3, 2, 1]:
        ls = insert_bst(ls, v)
    for k in range(1, 6):
        for fn in all_fns:
            assert fn(ls, k) == k

    # Right-skewed: 1..10 inorder.
    rs: Optional[TreeNode] = None
    for v in range(1, 11):
        rs = insert_bst(rs, v)
    for k in range(1, 11):
        for fn in all_fns:
            assert fn(rs, k) == k

    # Error paths: empty tree, k past the end, and non-positive k.
    for fn in early_fns:
        expect_raises(ValueError, fn, None, 1)
        expect_raises(ValueError, fn, t, 5)
        expect_raises(ValueError, fn, t, 0)
        expect_raises(ValueError, fn, t, -1)

    # Augmented BST — empty tree.
    empty = OrderStatBST()
    assert len(empty) == 0
    assert 1 not in empty
    expect_raises(ValueError, empty.kth_smallest, 1)
    expect_raises(KeyError, empty.rank, 1)
    empty.delete(1)             # delete on empty — no-op
    empty.assert_sizes()

    # Augmented BST — insert-delete-reinsert sanity.
    tree = OrderStatBST()
    for v in [5, 3, 8, 1, 4, 7, 9]:
        tree.insert(v)
    tree.assert_sizes()
    assert len(tree) == 7
    assert tree.kth_smallest(1) == 1
    assert tree.kth_smallest(4) == 5
    assert tree.kth_smallest(7) == 9
    assert tree.rank(7) == 5
    tree.delete(5)              # root delete
    tree.assert_sizes()
    assert 5 not in tree
    assert tree.kth_smallest(4) == 7
    tree.insert(5)
    tree.assert_sizes()
    assert 5 in tree
    assert tree.kth_smallest(4) == 5
    # delete a leaf.
    tree.delete(1)
    tree.assert_sizes()
    assert tree.kth_smallest(1) == 3
    # delete absent — no-op.
    tree.delete(999)
    tree.assert_sizes()
    assert len(tree) == 6
    # duplicate insert — no-op.
    tree.insert(5)
    tree.assert_sizes()
    assert len(tree) == 6
    # out-of-range k and absent key raise.
    expect_raises(ValueError, tree.kth_smallest, 0)
    expect_raises(ValueError, tree.kth_smallest, 7)
    expect_raises(KeyError, tree.rank, 1)

    print("edge_cases: PASS")


if __name__ == "__main__":
    # Script entry point: run the hand-written cases first, then the
    # randomized cross-check; either one aborts on its first failed assertion.
    for suite in (edge_cases, stress_test):
        suite()
    print("ALL TESTS PASSED")
