"""
LC 98 — Validate Binary Search Tree.

Four implementations, all O(N) time except the brute force (O(N^2)):
    1. validate_bounds            — recursive bounds-passing with None sentinels.
    2. validate_inorder_iterative — iterative inorder with `prev`; stack-safe.
    3. validate_inorder_recursive — recursive inorder with nonlocal `prev`.
    4. validate_brute_n2          — O(N^2) oracle: at every node, scan subtrees.

INVARIANT (BST): for every node, ALL values in its left subtree are strictly
less than node.val, AND ALL values in its right subtree are strictly greater.
"""

from __future__ import annotations

import random
import sys
from collections import deque
from typing import Optional

# Raise the interpreter-wide recursion limit so the recursive validators
# (validate_bounds, validate_inorder_recursive, validate_brute_n2) can handle
# deeper trees. NOTE: this is a process-global side effect of importing the module.
sys.setrecursionlimit(10_000)


class TreeNode:
    """Binary-tree node: a value plus optional left and right children."""

    # Fixed attribute set: no per-instance __dict__, which keeps big trees compact.
    __slots__ = ("val", "left", "right")

    def __init__(
        self,
        val: int = 0,
        left: Optional["TreeNode"] = None,
        right: Optional["TreeNode"] = None,
    ):
        self.val, self.left, self.right = val, left, right


# ----------------------------------------------------------------------
# Solution A — bounds-passing recursion (clearest statement of the invariant; recursion depth = tree height).
# ----------------------------------------------------------------------

def validate_bounds(root: Optional[TreeNode]) -> bool:
    """Check the BST property by passing an open interval (lo, hi) down the tree.

    Every node must lie strictly inside the interval inherited from its ancestors.
    A bound of None means that side is unbounded, so no integer value can collide
    with a sentinel. Recursion depth grows with the tree height.
    """

    def within(node: Optional[TreeNode], lo: Optional[int], hi: Optional[int]) -> bool:
        if node is None:
            return True
        # Both tests short-circuit on an unbounded (None) side.
        if (lo is not None and node.val <= lo) or (hi is not None and node.val >= hi):
            return False
        # Going left tightens the upper bound; going right tightens the lower bound.
        if not within(node.left, lo, node.val):
            return False
        return within(node.right, node.val, hi)

    return within(root, None, None)


# ----------------------------------------------------------------------
# Solution B — iterative inorder, single `prev` (the shippable form).
# ----------------------------------------------------------------------

def validate_inorder_iterative(root: Optional[TreeNode]) -> bool:
    """Return True iff an inorder walk of `root` yields strictly increasing values.

    Uses an explicit stack instead of recursion, so deep trees are not limited by
    the interpreter's recursion limit. Stops at the first value that is not
    strictly greater than the value visited just before it.
    """
    pending: list[TreeNode] = []
    node = root
    last: Optional[int] = None
    while True:
        # Push the left spine of the current subtree; its bottom is the next inorder node.
        while node is not None:
            pending.append(node)
            node = node.left
        if not pending:
            return True  # traversal finished without a violation
        visited = pending.pop()
        if last is not None and visited.val <= last:
            return False
        last = visited.val
        node = visited.right


# ----------------------------------------------------------------------
# Solution C — recursive inorder with nonlocal `prev` (compact form).
# ----------------------------------------------------------------------

def validate_inorder_recursive(root: Optional[TreeNode]) -> bool:
    """Return True iff a recursive inorder walk of `root` is strictly increasing.

    The most recently visited value is kept in the enclosing scope (`nonlocal`),
    so each node is compared with its inorder predecessor. Recursion depth grows
    with the tree height.
    """
    last_seen: Optional[int] = None

    def ascending(node: Optional[TreeNode]) -> bool:
        nonlocal last_seen
        if node is None:
            return True
        if not ascending(node.left):
            return False  # a violation below already settled the answer
        if last_seen is not None and node.val <= last_seen:
            return False
        last_seen = node.val
        return ascending(node.right)

    return ascending(root)


# ----------------------------------------------------------------------
# Brute force — O(N^2) oracle.
# ----------------------------------------------------------------------

def validate_brute_n2(root: Optional[TreeNode]) -> bool:
    """Reference oracle that applies the BST definition literally, node by node.

    For every node, gather all values in its left and right subtrees and require
    them to be strictly below / strictly above the node's value. Quadratic in the
    worst case, which is acceptable for a cross-checking oracle.
    """

    def gather(start: Optional[TreeNode]) -> list[int]:
        # Iterative DFS over the subtree rooted at `start`.
        collected: list[int] = []
        todo = [] if start is None else [start]
        while todo:
            current = todo.pop()
            collected.append(current.val)
            for child in (current.left, current.right):
                if child:
                    todo.append(child)
        return collected

    def holds_at(node: Optional[TreeNode]) -> bool:
        if node is None:
            return True
        if any(v >= node.val for v in gather(node.left)):
            return False
        if any(v <= node.val for v in gather(node.right)):
            return False
        return holds_at(node.left) and holds_at(node.right)

    return holds_at(root)


# ----------------------------------------------------------------------
# Tree-building helpers and random generators.
# ----------------------------------------------------------------------

def build_tree_from_list(vals: list[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a LeetCode-style level-order list (None marks a missing child).

    Values are consumed two at a time (left, then right) for each parent, in the
    order the parents were created. Parsing stops when the list runs out or no
    parent is left to receive children. An empty list or a leading None yields None.
    """
    if not vals:
        return None
    values = iter(vals)
    head = next(values)
    if head is None:
        return None
    root = TreeNode(head)
    parents: deque[TreeNode] = deque([root])
    end = object()  # distinct from None, which is a legitimate "no child" entry
    while parents:
        parent = parents.popleft()
        left_val = next(values, end)
        if left_val is end:
            break
        if left_val is not None:
            parent.left = TreeNode(left_val)
            parents.append(parent.left)
        right_val = next(values, end)
        if right_val is end:
            break
        if right_val is not None:
            parent.right = TreeNode(right_val)
            parents.append(parent.right)
    return root


def build_random_bst(rng: random.Random, n: int) -> Optional[TreeNode]:
    """Return a VALID BST of n distinct random ints from [-10_000, 10_000), or None if n == 0.

    The values come from one rng.sample call and are inserted in that (random)
    order with plain, unbalanced BST insertion, so the shape varies between runs.
    rng.sample rejects an n larger than the 20_000-value range.
    """
    if n == 0:
        return None
    root: Optional[TreeNode] = None
    for key in rng.sample(range(-10_000, 10_000), n):
        leaf = TreeNode(key)
        if root is None:
            root = leaf
            continue
        cursor = root
        while True:
            # Smaller keys go left; everything else goes right (keys are distinct).
            if key < cursor.val:
                if cursor.left is None:
                    cursor.left = leaf
                    break
                cursor = cursor.left
            else:
                if cursor.right is None:
                    cursor.right = leaf
                    break
                cursor = cursor.right
    return root


def build_random_binary_tree(rng: random.Random, n: int) -> Optional[TreeNode]:
    """Build a random-shaped binary tree of n nodes with values drawn from [-100, 100].

    Values may repeat and are placed without regard to ordering, so the result is
    usually NOT a valid BST. The first node is the root. Each later node hangs under
    a randomly chosen earlier node or, when that node is full, under the
    earliest-created node that still has a free slot. Returns None when n == 0.
    """
    if n == 0:
        return None
    pool = [TreeNode(rng.randint(-100, 100)) for _ in range(n)]
    for idx in range(1, n):
        child = pool[idx]
        host = pool[rng.randrange(idx)]
        if host.left is None and host.right is None:
            # Childless host: coin-flip which side the new node goes on.
            if rng.random() < 0.5:
                host.left = child
            else:
                host.right = child
        elif host.left is None:
            host.left = child
        elif host.right is None:
            host.right = child
        else:
            # Host is full: use the earliest-created node with an open slot, left before right.
            # One always exists: idx nodes expose 2*idx slots and only idx-1 are taken.
            owner = next((c for c in pool[:idx] if c.left is None or c.right is None), None)
            if owner is not None:
                if owner.left is None:
                    owner.left = child
                else:
                    owner.right = child
    return pool[0]


def mutate_swap_two_values(rng: random.Random, root: TreeNode) -> None:
    """Swap the values of two distinct, randomly chosen nodes in place.

    Does nothing for a single-node tree. Applied to a valid BST, the swap always
    breaks it: the inorder sequence is strictly increasing, and exchanging two of
    its entries creates a descent. On an arbitrary tree the result may or may not
    be a BST.
    """
    # Collect nodes in preorder, right subtree before left, because children are
    # pushed left-then-right onto a LIFO stack. rng.sample picks by position, so
    # this order determines which pair gets swapped.
    nodes: list[TreeNode] = []
    pending = [root]
    while pending:
        node = pending.pop()
        nodes.append(node)
        pending.extend(child for child in (node.left, node.right) if child)
    if len(nodes) >= 2:
        first, second = rng.sample(nodes, 2)
        first.val, second.val = second.val, first.val


# ----------------------------------------------------------------------
# Stress test.
# ----------------------------------------------------------------------

def stress_test(n_iter: int = 1000, seed: int = 42) -> None:
    """Cross-check all four validators on randomized trees.

    Runs ``n_iter // 2`` trials in each of three phases, all drawn from a single
    ``random.Random(seed)`` so a run is reproducible:

    * Part A: valid BSTs with distinct values. Every validator must return True.
    * Part B: the same kind of BST with two node values swapped. Every validator
      must return False.
    * Part C: random-shape trees with random, possibly duplicate values. These
      may or may not be BSTs, so the validators only have to agree.

    Raises AssertionError on the first failing trial; prints a summary otherwise.
    """
    rng = random.Random(seed)
    trials = n_iter // 2
    validators = (
        validate_bounds,
        validate_inorder_iterative,
        validate_inorder_recursive,
        validate_brute_n2,
    )

    def run_all(root: Optional[TreeNode]) -> tuple[bool, ...]:
        return tuple(fn(root) for fn in validators)

    # Part A: valid BSTs; all 4 implementations must accept.
    for _ in range(trials):
        size = rng.randint(0, 30)
        root = build_random_bst(rng, size)
        results = run_all(root)
        assert all(results), f"Valid BST rejected by some variant: {results}"

    # Part B: build_random_bst uses distinct values, so its inorder sequence is strictly
    # increasing. Swapping two entries of such a sequence always creates a descent, so
    # every mutated tree is INVALID (size >= 2 guarantees a swap actually happens).
    for _ in range(trials):
        size = rng.randint(2, 30)
        root = build_random_bst(rng, size)
        assert root is not None
        mutate_swap_two_values(rng, root)
        results = run_all(root)
        assert len(set(results)) == 1, f"Disagreement on mutated tree: {results}"
        assert not any(results), f"Swap-mutated BST accepted: {results}"

    # Part C: arbitrary random trees (occasionally valid BSTs); all 4 must agree.
    for _ in range(trials):
        size = rng.randint(1, 30)
        root = build_random_binary_tree(rng, size)
        results = run_all(root)
        assert len(set(results)) == 1, f"Disagreement on random tree: {results}"

    print(
        f"stress_test: {trials} valid + {trials} invalid (mutated) + {trials} random "
        f"— all 4 variants agree."
    )


# ----------------------------------------------------------------------
# Edge cases.
# ----------------------------------------------------------------------

def edge_cases() -> None:
    """Run hand-picked cases through all four validators; print a PASS line on success.

    Raises AssertionError naming the first validator that returns the wrong answer.
    """
    validators = (
        validate_bounds,
        validate_inorder_iterative,
        validate_inorder_recursive,
        validate_brute_n2,
    )

    def expect(root: Optional[TreeNode], want: bool, label: str) -> None:
        # `is` (not ==) also pins that every validator returns a real bool.
        for fn in validators:
            assert fn(root) is want, f"{fn.__name__} failed: {label}"

    expect(None, True, "empty tree")
    expect(TreeNode(7), True, "single node")
    expect(build_tree_from_list([2, 1, 3]), True, "[2,1,3]")

    # The canonical trap [5,1,4,null,null,3,6]: every parent/child pair is locally
    # ordered, but 3 lies in 5's right subtree, so the tree is INVALID.
    trap = TreeNode(5,
                    TreeNode(1),
                    TreeNode(4, TreeNode(3), TreeNode(6)))
    expect(trap, False, "canonical trap [5,1,4,null,null,3,6]")

    # Strict ordering: a duplicate is invalid on either side.
    expect(TreeNode(1, TreeNode(1)), False, "duplicate as left child")
    expect(TreeNode(1, None, TreeNode(1)), False, "duplicate as right child")

    # 32-bit extremes. A solution that seeds its bounds (or `prev`) with INT_MIN/INT_MAX
    # instead of an "unbounded" sentinel wrongly rejects these single-node trees.
    expect(TreeNode(2_147_483_647), True, "INT_MAX single node")
    expect(TreeNode(-2_147_483_648), True, "INT_MIN single node")

    # Skewed chains in both directions.
    expect(TreeNode(3, TreeNode(2, TreeNode(1))), True, "left-skewed 3-2-1")
    expect(TreeNode(1, None, TreeNode(2, None, TreeNode(3))), True, "right-skewed 1-2-3")

    # Stack safety: only the iterative variant must handle a chain deeper than the
    # recursion limit. Left-skewed, values decreasing downward, so it is valid.
    deep: Optional[TreeNode] = None
    for v in range(2 * sys.getrecursionlimit()):
        deep = TreeNode(v, deep)
    assert validate_inorder_iterative(deep) is True, "validate_inorder_iterative is not stack-safe"

    print("edge_cases: PASS")


if __name__ == "__main__":
    # Deterministic edge cases first (cheap, pinpoint failures), then the randomized cross-check.
    for suite in (edge_cases, stress_test):
        suite()
    print("ALL TESTS PASSED")
