"""
p12 — Maximum Depth of Binary Tree
===================================
LeetCode 104 · Easy · Topics: Tree, DFS, BFS, Binary Tree

Sections:
  1. TreeNode class; build_tree / random_tree helpers (defined after section 4)
  2. max_depth_recursive  — 3-liner aggregate-from-children
  3. max_depth_bfs        — level-by-level with len(queue) snapshot
  4. max_depth_dfs_iter   — DFS stack with (node, depth) pairs
  5. stress_test          — three variants must agree on 1000 random trees
  6. edge_cases
"""

import random
import sys
from collections import deque
from typing import List, Optional

# max_depth_recursive recurses once per tree level, so a deep or skewed tree
# can exceed the interpreter's default recursion cap (typically 1000 in
# CPython). Raise it for the whole script.
_RECURSION_LIMIT = 10_000
sys.setrecursionlimit(_RECURSION_LIMIT)


class TreeNode:
    """Binary-tree node: an integer payload plus optional left/right children.

    ``__slots__`` removes the per-instance ``__dict__``, so nodes are lighter
    and assigning a misspelled attribute raises ``AttributeError`` instead of
    silently creating a new field.
    """

    __slots__ = "val", "left", "right"

    def __init__(
        self,
        val: int = 0,
        left: Optional["TreeNode"] = None,
        right: Optional["TreeNode"] = None,
    ) -> None:
        # Store payload and both child links in one step.
        self.val, self.left, self.right = val, left, right


# ----------------------------------------------------------------------
# 2. RECURSIVE — the canonical aggregate-from-children template
# ----------------------------------------------------------------------
def max_depth_recursive(node: Optional[TreeNode]) -> int:
    """Return the number of nodes on the longest root-to-leaf path.

    An empty tree (``None``) has depth 0. Otherwise the depth is one (for
    ``node`` itself) plus the larger of the two subtree depths. Python call
    depth grows with tree depth, so very deep trees can hit the recursion
    limit.
    """
    if node is None:  # base case: an empty subtree adds no levels
        return 0
    # Combine step: left subtree is evaluated first, then right.
    child_depth = max(max_depth_recursive(node.left), max_depth_recursive(node.right))
    return child_depth + 1


# ----------------------------------------------------------------------
# 3. BFS — count levels using the len(queue) SNAPSHOT pattern
# ----------------------------------------------------------------------
def max_depth_bfs(node: Optional[TreeNode]) -> int:
    """Return the tree's depth by counting breadth-first levels.

    Each pass of the outer loop drains exactly one level from the queue and
    enqueues the next, so the number of passes equals the depth. ``None``
    yields 0.
    """
    if node is None:
        return 0
    frontier: deque = deque((node,))
    levels = 0
    while frontier:
        # Freeze the current level's size before enqueuing any children;
        # otherwise nodes of the next level would be drained in this pass.
        level_size = len(frontier)
        for _ in range(level_size):
            current = frontier.popleft()
            for child in (current.left, current.right):
                if child is not None:
                    frontier.append(child)
        levels += 1
    return levels


# ----------------------------------------------------------------------
# 4. ITERATIVE DFS — explicit (node, depth) pairs
# ----------------------------------------------------------------------
def max_depth_dfs_iter(node: Optional[TreeNode]) -> int:
    """Return the tree's depth using an explicit DFS stack.

    Every stack entry carries its own 1-based depth, so no recursion is
    needed. The answer is the largest depth popped. ``None`` yields 0.
    """
    if node is None:
        return 0
    pending: List = [(node, 1)]
    deepest = 0
    while pending:
        current, level = pending.pop()
        deepest = max(deepest, level)
        # Children sit one level below their parent.
        for child in (current.left, current.right):
            if child is not None:
                pending.append((child, level + 1))
    return deepest


# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Deserialize a LeetCode-style level-order list into a binary tree.

    ``None`` entries mark missing children. Only real nodes consume two
    following slots (left, then right). Returns ``None`` when the list is
    empty or its first entry is ``None``. Slots left over once no parents
    remain are ignored.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    parents: deque = deque([root])
    pos, total = 1, len(values)
    while parents and pos < total:
        parent = parents.popleft()
        # Left slot: pos < total is guaranteed by the loop condition.
        left_val = values[pos]
        if left_val is not None:
            parent.left = TreeNode(left_val)
            parents.append(parent.left)
        pos += 1
        if pos >= total:  # list ended right after a left child
            break
        right_val = values[pos]
        if right_val is not None:
            parent.right = TreeNode(right_val)
            parents.append(parent.right)
        pos += 1
    return root


def random_tree(rng: random.Random, max_nodes: int = 30) -> Optional[TreeNode]:
    """Build a random tree from a random level-order list.

    Draws a slot count in ``[0, max_nodes]`` and returns ``None`` for 0. The
    root slot always holds a value in ``[-100, 100]``. Each later slot is
    ``None`` with probability 0.3, otherwise a value in the same range. The
    tree has at most ``max_nodes`` nodes; it has fewer when slots are
    ``None`` or go unused by ``build_tree``. All randomness comes from
    ``rng``, in a fixed draw order, so a seeded ``rng`` gives reproducible
    trees.
    """
    slot_count = rng.randint(0, max_nodes)
    if not slot_count:
        return None
    root_val = rng.randint(-100, 100)
    rest: List[Optional[int]] = [
        None if rng.random() < 0.3 else rng.randint(-100, 100)
        for _ in range(slot_count - 1)
    ]
    return build_tree([root_val, *rest])


# ----------------------------------------------------------------------
# 5. STRESS TEST
# ----------------------------------------------------------------------
def stress_test(iterations: int = 1000, *, seed: int = 42) -> None:
    """Cross-check the three max-depth variants on seeded random trees.

    Args:
        iterations: number of random trees to generate and compare.
        seed: seed for the private ``random.Random`` instance. Defaults to 42
            (the original fixed seed), so default runs are unchanged; pass
            the seed from a failure message to reproduce that run exactly.

    Raises:
        AssertionError: if the variants disagree on any tree. The error is
            raised explicitly, not via ``assert``, so the check is not
            stripped when Python runs with ``-O``.
    """
    rng = random.Random(seed)
    for i in range(iterations):
        t = random_tree(rng)
        a = max_depth_recursive(t)
        b = max_depth_bfs(t)
        c = max_depth_dfs_iter(t)
        if not a == b == c:
            raise AssertionError(
                f"disagreement at iteration {i} (seed={seed}): "
                f"recursive={a} bfs={b} dfs={c}"
            )
    print(f"stress_test PASSED — {iterations} iterations (3 variants agree)")


# ----------------------------------------------------------------------
# 6. EDGE CASES
# ----------------------------------------------------------------------
def edge_cases() -> None:
    """Check every max-depth variant against hand-computed depths.

    The stress test only shows that the variants agree with each other.
    This anchors all three to known answers. Each case prints PASS/FAIL with
    every variant's result.

    Raises:
        AssertionError: after all cases are reported, if any variant returned
            a wrong depth. Raised explicitly so ``python -O`` cannot skip it.
    """
    variants = (
        ("recursive", max_depth_recursive),
        ("bfs", max_depth_bfs),
        ("dfs", max_depth_dfs_iter),
    )
    cases = [
        ([], 0, "empty"),
        ([None], 0, "null root"),
        ([1], 1, "single node"),
        ([1, 2], 2, "root + left only"),
        ([1, None, 2], 2, "root + right only"),
        ([3, 9, 20, None, None, 15, 7], 3, "canonical LC example"),
        ([1, 2, 3, 4, 5, 6, 7], 3, "perfect tree depth 3"),
        ([1, 2, None, 3, None, 4, None, 5], 5, "left-skewed depth 5"),
        ([1, None, 2, None, 3], 3, "right-skewed depth 3"),
    ]
    failures: List[str] = []
    for vals, expected, label in cases:
        tree = build_tree(vals)
        got = {name: fn(tree) for name, fn in variants}
        ok = all(depth == expected for depth in got.values())
        status = "PASS" if ok else "FAIL"
        detail = " ".join(f"{name}={depth}" for name, depth in got.items())
        print(f"  [{status}] {label}: expected={expected} got {detail}")
        if not ok:
            failures.append(label)
    # Report every case first, then fail once, so one run shows all regressions.
    if failures:
        raise AssertionError(f"edge_cases FAILED: {', '.join(failures)}")


if __name__ == "__main__":
    # Run each suite in order under its banner: edge cases, then stress test.
    for banner, suite in (
        ("=== Edge cases ===", edge_cases),
        ("\n=== Stress test ===", stress_test),
    ):
        print(banner)
        suite()
