"""
p43 — Find Median from Data Stream (LeetCode 295, HARD).

Two-heap dance:
    lo: max-heap of smaller half (negated in Python).
    hi: min-heap of larger half.
Invariants (after every addNum):
    1. |lo| in {|hi|, |hi| + 1}.
    2. max(lo) <= min(hi).

Median: if |lo| > |hi|: top(lo); else avg(top(lo), top(hi)).

Stress: 50 seeded random trials, each up to 200 add/findMedian ops, checked against a sorted-list oracle.
"""
from __future__ import annotations
import heapq
import random
from bisect import insort
from typing import List


# --------------------------------------------------------------------------------------
# Canonical: lo-then-move-to-hi-then-rebalance.
# --------------------------------------------------------------------------------------
class MedianFinder:
    """Running median over a stream of numbers, backed by two heaps.

    ``lo`` holds the smaller half as a max-heap (values are stored negated
    because ``heapq`` only implements min-heaps); ``hi`` holds the larger
    half as a min-heap. After every ``addNum`` call, ``len(lo)`` equals
    ``len(hi)`` or ``len(hi) + 1`` and max(lo) <= min(hi).
    """

    def __init__(self) -> None:
        self.lo: List[int] = []   # max-heap (negated)
        self.hi: List[int] = []   # min-heap

    def addNum(self, num: int) -> None:
        """Insert ``num`` into the stream, restoring both heap invariants."""
        # Route num through lo: heappushpop pushes -num and pops lo's largest
        # value in a single call. Handing that value to hi keeps every element
        # of lo <= every element of hi.
        heapq.heappush(self.hi, -heapq.heappushpop(self.lo, -num))
        # hi just grew by one; if it is now the larger heap, give its
        # smallest value back so lo never holds fewer elements than hi.
        if len(self.hi) > len(self.lo):
            heapq.heappush(self.lo, -heapq.heappop(self.hi))

    def findMedian(self) -> float:
        """Return the median of all numbers added so far.

        Raises:
            IndexError: if no number has been added yet.
        """
        if not self.lo:
            raise IndexError("findMedian() called on an empty MedianFinder")
        # Odd count: lo holds the extra element, which is the median.
        if len(self.lo) > len(self.hi):
            return float(-self.lo[0])
        # Even count: average the two middle values.
        return (-self.lo[0] + self.hi[0]) / 2.0


# --------------------------------------------------------------------------------------
# Alternative: branch-on-comparison. Same big-O, fewer ops in steady state.
# --------------------------------------------------------------------------------------
class MedianFinderBranching:
    """Two-heap running median that picks the target heap by comparison.

    ``lo`` is a max-heap of the smaller half (values stored negated) and
    ``hi`` is a min-heap of the larger half. Each number is pushed directly
    onto the heap it belongs to, then at most one element is moved to
    restore ``len(lo) in {len(hi), len(hi) + 1}``.
    """

    def __init__(self) -> None:
        self.lo: List[int] = []   # max-heap (negated)
        self.hi: List[int] = []   # min-heap

    def addNum(self, num: int) -> None:
        """Insert ``num`` into the stream, restoring the size invariant."""
        # Values no larger than lo's current max belong to the lower half;
        # the very first value also goes to lo.
        if not self.lo or num <= -self.lo[0]:
            heapq.heappush(self.lo, -num)
        else:
            heapq.heappush(self.hi, num)
        # Rebalance both directions; a single insertion can break the size
        # invariant by at most one element, so one move is enough.
        if len(self.lo) > len(self.hi) + 1:
            heapq.heappush(self.hi, -heapq.heappop(self.lo))
        elif len(self.hi) > len(self.lo):
            heapq.heappush(self.lo, -heapq.heappop(self.hi))

    def findMedian(self) -> float:
        """Return the median of all numbers added so far.

        Raises:
            IndexError: if no number has been added yet.
        """
        if not self.lo:
            raise IndexError(
                "findMedian() called on an empty MedianFinderBranching"
            )
        # Odd count: lo holds the extra element, which is the median.
        if len(self.lo) > len(self.hi):
            return float(-self.lo[0])
        # Even count: average the two middle values.
        return (-self.lo[0] + self.hi[0]) / 2.0


# --------------------------------------------------------------------------------------
# Oracle: sorted-list baseline.
# --------------------------------------------------------------------------------------
class MedianFinderSorted:
    """Reference implementation that keeps every value in a sorted list.

    Used as the oracle for the heap-based finders: the median is read
    straight from the middle of ``a``.
    """

    def __init__(self) -> None:
        self.a: List[int] = []   # all values seen so far, ascending

    def addNum(self, num: int) -> None:
        """Insert ``num`` while keeping ``a`` sorted."""
        insort(self.a, num)

    def findMedian(self) -> float:
        """Return the median of all numbers added so far.

        Raises:
            IndexError: if no number has been added yet.
        """
        n = len(self.a)
        if n == 0:
            # Without this guard the even-count branch would index a[-1]
            # and fail with an uninformative "list index out of range".
            raise IndexError(
                "findMedian() called on an empty MedianFinderSorted"
            )
        mid = n // 2
        if n % 2 == 1:
            return float(self.a[mid])
        return (self.a[mid - 1] + self.a[mid]) / 2.0


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Run hand-written sanity checks against both heap-based finders.

    Fails with ``AssertionError`` on the first mismatch and prints a PASS
    line when every check succeeds.
    """
    # LeetCode's worked example: the median is checked after each insertion.
    finder = MedianFinder()
    for value, expected in ((1, 1.0), (2, 1.5), (3, 2.0)):
        finder.addNum(value)
        assert finder.findMedian() == expected

    # Whole-stream scenarios: feed every value, then check the final median.
    scenarios = (
        ([7] * 5, 7.0),       # all duplicates
        ([-5, 0, 5], 0.0),    # negatives + zero
        ([4, 2], 3.0),        # even count
    )
    for stream, expected in scenarios:
        finder = MedianFinder()
        for value in stream:
            finder.addNum(value)
        assert finder.findMedian() == expected

    # Branching variant on an ascending run of five values.
    branching = MedianFinderBranching()
    for value in range(1, 6):
        branching.addNum(value)
    assert branching.findMedian() == 3.0

    print("edge_cases: PASS")


def stress_test() -> None:
    """Compare both heap-based finders with the sorted-list oracle.

    Runs 50 trials from a fixed seed. Each trial performs between 1 and 200
    operations: an ``addNum`` of a value in [-1000, 1000] with probability
    0.7 (always when nothing has been added yet), otherwise a ``findMedian``
    comparison. The medians are also compared once more at the end of every
    trial, so trailing insertions are verified too.

    Raises:
        AssertionError: if either heap-based finder disagrees with the oracle.
    """

    def assert_agreement(trial, canonical, branching, oracle) -> None:
        # Compare all three implementations on the current stream contents.
        ma, mb, mo = (
            canonical.findMedian(),
            branching.findMedian(),
            oracle.findMedian(),
        )
        assert ma == mo, f"trial {trial}: canonical {ma} != oracle {mo}, arr={oracle.a}"
        assert mb == mo, f"trial {trial}: branching {mb} != oracle {mo}, arr={oracle.a}"

    rng = random.Random(42)
    for trial in range(50):
        a = MedianFinder()
        b = MedianFinderBranching()
        oracle = MedianFinderSorted()
        n_ops = rng.randint(1, 200)
        for _ in range(n_ops):
            # rng.random() is drawn first on every iteration, so the random
            # stream does not depend on whether the oracle is empty.
            if rng.random() < 0.7 or len(oracle.a) == 0:
                # addNum
                x = rng.randint(-1000, 1000)
                a.addNum(x)
                b.addNum(x)
                oracle.addNum(x)
            else:
                # findMedian — compare all three
                assert_agreement(trial, a, b, oracle)
        # The first op of a trial is always an add, so the stream is
        # non-empty here; check the final state even if the trial ended on
        # adds or never drew a findMedian op.
        assert_agreement(trial, a, b, oracle)

    print("stress_test: 50 trials × up to 200 random ops — both heaps agree with sorted oracle.")


if __name__ == "__main__":
    # Script entry point: hand-written checks first, then the randomized
    # comparison against the oracle.
    for run_suite in (edge_cases, stress_test):
        run_suite()
    print("ALL TESTS PASSED")
