"""
p14 — Diameter of Binary Tree
==============================
LeetCode 543 · Easy (mislabeled — Medium in disguise) · Topics: Tree, DFS

The TWO-CHANNEL pattern: the recursion RETURNS height (for parent's use)
while ACCUMULATING the best diameter as a side effect via nonlocal closure.

Sections:
  1. TreeNode + helpers
  2. diameter_brute             — O(N^2) reference oracle
  3. diameter_two_channel       — O(N) with nonlocal best
  4. diameter_two_channel_tuple — O(N) with tuple return (height, best)
  5. stress_test                — brute vs both optimized must agree
  6. edge_cases
"""

import random
import sys
from collections import deque
from typing import List, Optional, Tuple

# The tree traversals below are recursive (one Python frame per tree level),
# so raise the interpreter's recursion limit to leave headroom for deep,
# skewed trees.
sys.setrecursionlimit(10_000)


class TreeNode:
    """Binary-tree node: a value plus optional left and right children."""

    # Fixed attribute set; instances carry no per-object __dict__.
    __slots__ = ("val", "left", "right")

    def __init__(
        self,
        val: int = 0,
        left: Optional["TreeNode"] = None,
        right: Optional["TreeNode"] = None,
    ) -> None:
        self.val, self.left, self.right = val, left, right


# ----------------------------------------------------------------------
# 2. BRUTE FORCE — O(N^2): height() at every node
# ----------------------------------------------------------------------
def diameter_brute(root: Optional[TreeNode]) -> int:
    """Return the tree's diameter in edges using the brute-force oracle.

    At every node the heights of both subtrees are recomputed from scratch,
    so the same subtrees are measured repeatedly; this is the slow reference
    the faster variants are compared against. An empty tree yields 0.
    """

    def depth(sub: Optional[TreeNode]) -> int:
        # Height counted in NODES: empty subtree -> 0, leaf -> 1.
        return 0 if sub is None else 1 + max(depth(sub.left), depth(sub.right))

    def longest(sub: Optional[TreeNode]) -> int:
        # Best path length (in edges) found anywhere inside `sub`.
        if sub is None:
            return 0
        # Node-count heights of the two sides add up to the EDGE count of
        # the path that bends at `sub`.
        through_here = depth(sub.left) + depth(sub.right)
        below = max(longest(sub.left), longest(sub.right))
        return max(through_here, below)

    return longest(root)


# ----------------------------------------------------------------------
# 3. TWO-CHANNEL OPTIMAL — O(N) with nonlocal best
# ----------------------------------------------------------------------
def diameter_two_channel(root: Optional[TreeNode]) -> int:
    """Return the tree's diameter in edges with a single post-order pass.

    The inner recursion returns each subtree's height to its parent while
    recording, through a ``nonlocal`` variable, the longest path seen so far.
    An empty tree yields 0.
    """
    longest = 0

    def depth(sub: Optional[TreeNode]) -> int:
        # Returns the height of `sub` counted in NODES (empty=0, leaf=1), so
        # left_depth + right_depth is the EDGE count of the path bending here.
        nonlocal longest
        if sub is None:
            return 0
        left_depth = depth(sub.left)
        right_depth = depth(sub.right)
        # Channel A (side effect): keep the best diameter candidate so far.
        longest = max(longest, left_depth + right_depth)
        # Channel B (return value): height handed up to the parent.
        return max(left_depth, right_depth) + 1

    depth(root)
    return longest


# ----------------------------------------------------------------------
# 4. TWO-CHANNEL VIA TUPLE RETURN — equivalent, no closure
# ----------------------------------------------------------------------
def diameter_two_channel_tuple(root: Optional[TreeNode]) -> int:
    """Return the tree's diameter in edges without using a closure variable.

    Each recursive call reports both values its parent needs as a pair:
    the subtree height (in nodes) and the best diameter (in edges) found
    inside that subtree. An empty tree yields 0.
    """

    def solve(sub: Optional[TreeNode]) -> Tuple[int, int]:
        # -> (height_in_nodes, best_diameter_in_edges_within_sub)
        if sub is None:
            return (0, 0)
        left_height, left_best = solve(sub.left)
        right_height, right_best = solve(sub.right)
        taller = left_height if left_height >= right_height else right_height
        # Path bending at `sub` uses the full height of both sides.
        bend = left_height + right_height
        return taller + 1, max(bend, left_best, right_best)

    _, answer = solve(root)
    return answer


# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where ``None`` marks a missing child.

    Slots after the root are consumed in (left, right) pairs, one pair per
    node taken from the queue in creation order. Missing children are not
    queued, so they consume no further slots. Returns ``None`` when the list
    is empty or its first entry is ``None``.
    """
    if not values or values[0] is None:
        return None

    root = TreeNode(values[0])
    frontier: deque = deque((root,))
    total = len(values)
    pos = 1
    while frontier and pos < total:
        parent = frontier.popleft()
        # Read this parent's two slots; a slot past the end of the list
        # counts as "no child", exactly like an explicit None.
        left_val = values[pos]
        right_val = values[pos + 1] if pos + 1 < total else None
        pos += 2
        if left_val is not None:
            parent.left = TreeNode(left_val)
            frontier.append(parent.left)
        if right_val is not None:
            parent.right = TreeNode(right_val)
            frontier.append(parent.right)
    return root


def random_tree(rng: random.Random, max_nodes: int = 40) -> Optional[TreeNode]:
    """Generate a random tree by building it from a random level-order list.

    The list has between 0 and ``max_nodes`` slots. The first slot is always
    a value; each later slot is ``None`` with probability 0.3, otherwise a
    value in [-100, 100]. The resulting tree therefore has at most
    ``max_nodes`` nodes. Returns ``None`` when zero slots are drawn.
    """
    size = rng.randint(0, max_nodes)
    if not size:
        return None

    slots: List[Optional[int]] = [rng.randint(-100, 100)]
    for _ in range(size - 1):
        # Draw the "hole" coin first; a value is drawn only for real nodes.
        if rng.random() < 0.3:
            slots.append(None)
        else:
            slots.append(rng.randint(-100, 100))
    return build_tree(slots)


# ----------------------------------------------------------------------
# 5. STRESS TEST
# ----------------------------------------------------------------------
def stress_test(iterations: int = 1000) -> None:
    """Cross-check the three diameter implementations on random trees.

    Uses a fixed seed (42) so runs are reproducible. Asserts that the brute
    force oracle and both single-pass variants agree on every generated
    tree, then prints a summary line.
    """
    rng = random.Random(42)
    for _ in range(iterations):
        tree = random_tree(rng)
        brute, two_ch, tup = (
            diameter_brute(tree),
            diameter_two_channel(tree),
            diameter_two_channel_tuple(tree),
        )
        assert brute == two_ch == tup, (
            f"disagreement: brute={brute} two_ch={two_ch} tuple={tup}"
        )
    print(f"stress_test PASSED — {iterations} iterations (3 variants agree)")


# ----------------------------------------------------------------------
# 6. EDGE CASES
# ----------------------------------------------------------------------
def edge_cases() -> None:
    """Run hand-picked trees through the solvers, report, and fail loudly.

    Each case is ``(level_order_values, expected_diameter, label)``. A case
    with an expected value prints a PASS/FAIL line for
    ``diameter_two_channel``; a case whose expected value is ``None`` is
    cross-checked against ``diameter_brute`` instead. Every case is also
    checked against ``diameter_two_channel_tuple`` so both single-pass
    variants are covered.

    Raises:
        AssertionError: after all cases have been reported, if any case
            produced a wrong or inconsistent answer. Raised explicitly
            (not via ``assert``) so it still fires under ``python -O``.
    """
    cases: List[Tuple[List[Optional[int]], Optional[int], str]] = [
        ([], 0, "empty"),
        ([1], 0, "single node — 0 edges"),
        ([1, 2], 1, "two nodes — 1 edge"),
        ([1, 2, 3], 2, "three-node V — 2 edges"),
        ([1, 2, 3, 4, 5], 3, "LC example — 3 edges (4→2→1→3)"),
        ([1, 2, None, 3, None, 4, None, 5], 4, "left-skewed length-5"),
        # T-shape: the 4-edge optimum is reached both by a path ending at the
        # root (7→5→4→2→1) and by one bending at node 2 (7→5→4→2→3).
        ([1, 2, None, 3, 4, None, None, 5, 6, 7, 8], None, "T-shape (computed)"),
        # Longest path (7→5→3→2→4→6→8, 6 edges) bends at node 2 and never
        # touches the root; the best path through the root has only 4 edges.
        (
            [1, 2, None, 3, 4, 5, None, None, 6, 7, None, None, 8],
            6,
            "deep V under node 2 — diameter avoids root",
        ),
    ]
    failures: List[str] = []
    for vals, expected, label in cases:
        t = build_tree(vals)
        got = diameter_two_channel(t)
        if expected is None:
            print(f"  [INFO] {label}: got={got} (no expected; sanity-check vs brute)")
            brute = diameter_brute(t)
            if got != brute:
                failures.append(
                    f"{label}: two_channel disagrees with brute: {got} vs {brute}"
                )
        else:
            status = "PASS" if got == expected else "FAIL"
            print(f"  [{status}] {label}: expected={expected} got={got}")
            if got != expected:
                failures.append(f"{label}: expected={expected} got={got}")
        # Keep the tuple-returning variant honest on the same inputs.
        tuple_got = diameter_two_channel_tuple(t)
        if tuple_got != got:
            failures.append(
                f"{label}: tuple variant disagrees with two_channel: "
                f"{tuple_got} vs {got}"
            )
    if failures:
        # Report everything at once so one bad case does not hide the others.
        raise AssertionError("edge_cases failed: " + "; ".join(failures))


if __name__ == "__main__":
    # Script entry point: run each self-check under its own banner, in order.
    for banner, runner in (
        ("=== Edge cases ===", edge_cases),
        ("\n=== Stress test ===", stress_test),
    ):
        print(banner)
        runner()
