"""
p41 — Kth Largest Element in an Array (LeetCode 215, MEDIUM).

Implementations:
    find_kth_largest_heap         — min-heap of size K. O(n log K). Production default.
    find_kth_largest_quickselect  — random-pivot 3-way split. O(n) expected.
    find_kth_largest_sort         — sort baseline. O(n log n).
    find_kth_largest_nlargest     — heapq.nlargest library shortcut.
    KthLargest                    — LC 703 streaming class.

Stress test: 1000 random arrays, all four batch implementations agree.

Run: python3 solution.py
"""
from __future__ import annotations
import heapq
import random
from typing import List


# --------------------------------------------------------------------------------------
# Min-heap of size K. The root is the smallest-of-K, i.e., the current K-th largest.
# INVARIANT: heap holds exactly K largest seen so far (or fewer if we haven't seen K).
# --------------------------------------------------------------------------------------
def find_kth_largest_heap(nums: List[int], k: int) -> int:
    """Return the k-th largest element of ``nums`` (by rank, duplicates counted).

    A min-heap capped at ``k`` entries holds the k largest values seen so far, so
    once every element has been visited its root is the answer.
    O(n log k) time, O(k) extra space. ``nums`` is only iterated, never modified.

    Args:
        nums: Values to search.
        k: 1-based rank from the top; must satisfy ``1 <= k <= len(nums)``.

    Returns:
        The value found at index ``-k`` of ``sorted(nums)``.

    Raises:
        IndexError: If ``k < 1`` or ``nums`` holds fewer than ``k`` elements.
    """
    # IndexError (not ValueError) on purpose: it is what an out-of-range rank
    # already surfaced as via ``heap[0]`` on an empty heap.
    if k < 1:
        raise IndexError(f"k must be >= 1, got {k}")

    heap: List[int] = []
    for x in nums:
        if len(heap) < k:
            heapq.heappush(heap, x)
        elif x > heap[0]:
            # Heap is full and x beats the current k-th largest: swap the root out
            # in a single O(log k) sift.
            heapq.heapreplace(heap, x)

    # An under-filled heap means k exceeded the input size; its root would be the
    # overall minimum, which is not a k-th largest, so refuse to return it.
    if len(heap) < k:
        raise IndexError(f"k={k} exceeds the {len(heap)} element(s) supplied")
    return heap[0]


# --------------------------------------------------------------------------------------
# Quickselect with random pivot + 3-way (Dutch flag) split.
# Targets 0-indexed position (n - k) — i.e., K-th largest by rank.
# Expected O(n); the O(n²) worst case needs a long run of bad pivots, which random choice makes very unlikely.
# --------------------------------------------------------------------------------------
def find_kth_largest_quickselect(nums: List[int], k: int) -> int:
    """Return the k-th largest element of ``nums`` using randomized quickselect.

    Each round picks a random pivot and splits the candidates three ways
    (less / equal / greater), then keeps only the side that contains the wanted
    ascending position ``len(nums) - k``. Duplicates of the pivot are resolved in
    one round, so arrays full of equal values do not degrade. Expected O(n) time.
    ``nums`` is not modified.

    Args:
        nums: Values to search.
        k: 1-based rank from the top; must satisfy ``1 <= k <= len(nums)``.

    Returns:
        The value found at index ``-k`` of ``sorted(nums)``.

    Raises:
        IndexError: If ``k`` is outside ``1..len(nums)`` (this includes empty input).
    """
    # Snapshot once so random.choice always gets an indexable list.
    candidates = list(nums)
    # Without this guard a bad k yields a negative or oversized target, which either
    # walks down to the minimum or dies inside random.choice([]) with no context.
    if not 1 <= k <= len(candidates):
        raise IndexError(f"k={k} out of range for {len(candidates)} element(s)")

    target = len(candidates) - k  # 0-indexed position in ascending-sorted order

    # Loop invariant: 0 <= target < len(candidates), so candidates is never empty
    # and every round strictly shrinks it (the pivot itself is always dropped).
    while True:
        pivot = random.choice(candidates)
        lows = [x for x in candidates if x < pivot]
        highs = [x for x in candidates if x > pivot]
        # Whatever is neither lower nor higher equals the pivot.
        equal_count = len(candidates) - len(lows) - len(highs)

        if target < len(lows):
            candidates = lows
        elif target < len(lows) + equal_count:
            return pivot
        else:
            # Re-base the position so it is relative to the start of `highs`.
            target -= len(lows) + equal_count
            candidates = highs


# --------------------------------------------------------------------------------------
# Sort baseline. Always mention as baseline; never as primary in an interview.
# --------------------------------------------------------------------------------------
def find_kth_largest_sort(nums: List[int], k: int) -> int:
    """Return the k-th largest element of ``nums`` by sorting a copy.

    O(n log n) time, O(n) extra space. ``nums`` is not modified.

    Args:
        nums: Values to search.
        k: 1-based rank from the top; must satisfy ``1 <= k <= len(nums)``.

    Returns:
        The value at index ``-k`` of the ascending-sorted copy.

    Raises:
        IndexError: If ``k`` is outside ``1..len(nums)`` (this includes empty input).
    """
    ordered = sorted(nums)
    # Negative indexing alone is not a range check: ordered[-0] is ordered[0] (the
    # minimum), and a negative k would index from the small end instead of failing.
    if not 1 <= k <= len(ordered):
        raise IndexError(f"k={k} out of range for {len(ordered)} element(s)")
    return ordered[-k]


# --------------------------------------------------------------------------------------
# Library shortcut. Under the hood, heapq.nlargest uses heap-of-size-K.
# --------------------------------------------------------------------------------------
def find_kth_largest_nlargest(nums: List[int], k: int) -> int:
    """Return the k-th largest element of ``nums`` via ``heapq.nlargest``.

    ``nums`` is not modified.

    Args:
        nums: Values to search.
        k: 1-based rank from the top; must satisfy ``1 <= k <= len(nums)``.

    Returns:
        The last (smallest) of the ``k`` largest values.

    Raises:
        IndexError: If ``k < 1`` or ``nums`` holds fewer than ``k`` elements.
    """
    if k < 1:
        raise IndexError(f"k must be >= 1, got {k}")
    top = heapq.nlargest(k, nums)
    # nlargest returns a shorter list when the input has fewer than k items; taking
    # [-1] of that would hand back the overall minimum instead of failing.
    if len(top) < k:
        raise IndexError(f"k={k} exceeds the {len(top)} element(s) supplied")
    return top[-1]


# --------------------------------------------------------------------------------------
# LC 703: streaming variant. A min-heap capped at size K is the standard approach.
# --------------------------------------------------------------------------------------
class KthLargest:
    """Track the k-th largest value of a growing stream (LeetCode 703).

    State is a min-heap (``self.heap``) that never grows beyond ``self.k`` entries:
    the k largest values seen so far, with the k-th largest at the root.
    """

    def __init__(self, k: int, nums: List[int]):
        """Store ``k`` and feed every value of ``nums`` through :meth:`add`."""
        self.k = k
        self.heap: List[int] = []
        for initial in nums:
            self.add(initial)

    def add(self, val: int) -> int:
        """Offer ``val`` to the stream and return the current heap root.

        Once at least ``k`` values have been seen the root is the k-th largest.
        Before that the heap is under-filled and the root is simply the smallest
        value seen so far.
        """
        pool = self.heap
        if len(pool) >= self.k:
            # Full: only a value strictly above the root can displace it.
            if val > pool[0]:
                heapq.heappushpop(pool, val)
        else:
            heapq.heappush(pool, val)
        return pool[0]


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Run the hand-picked sanity checks and print a PASS line.

    Covers the LeetCode sample inputs, both rank extremes, an all-duplicate array,
    a single element, negative values, and the LC 703 streaming walkthrough.
    Raises AssertionError on the first mismatch.
    """
    # (implementation, input, k, expected) — evaluated in the order listed.
    batch_checks = (
        # LC canonical examples
        (find_kth_largest_heap, [3, 2, 1, 5, 6, 4], 2, 5),
        (find_kth_largest_heap, [3, 2, 3, 1, 2, 4, 5, 5, 6], 4, 4),
        # k = 1 (max), k = n (min)
        (find_kth_largest_heap, [1, 2, 3, 4, 5], 1, 5),
        (find_kth_largest_heap, [1, 2, 3, 4, 5], 5, 1),
        # All duplicates — rank counts repeats, it is not "k-th distinct"
        (find_kth_largest_heap, [7, 7, 7], 2, 7),
        (find_kth_largest_quickselect, [7, 7, 7], 2, 7),
        # Single element
        (find_kth_largest_heap, [42], 1, 42),
        # Negatives
        (find_kth_largest_heap, [-1, -5, -2, -3], 1, -1),
        (find_kth_largest_heap, [-1, -5, -2, -3], 4, -5),
    )
    for solve, values, rank, want in batch_checks:
        assert solve(values, rank) == want

    # Streaming class — LC 703 canonical sequence of (value added, expected answer)
    stream = KthLargest(3, [4, 5, 8, 2])
    for pushed, want in ((3, 4), (5, 5), (10, 5), (9, 8), (4, 8)):
        assert stream.add(pushed) == want

    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check every implementation against the others and print a summary.

    Three phases, each raising AssertionError on the first disagreement:

    1. 1000 seeded random arrays (length 1-30, values in [-50, 50]) with a random
       valid ``k``; all four batch implementations must return the same value.
    2. Structured inputs (all-equal, ascending, descending) of length 50 and 100 at
       ``k`` = 1, n // 2 and n; again all four batch implementations must agree.
    3. 100 seeded streaming runs of ``KthLargest`` compared with sorting everything
       seen so far; only checked once at least ``k`` values have been seen.
    """
    rng = random.Random(42)
    for _ in range(1000):
        n = rng.randint(1, 30)
        nums = [rng.randint(-50, 50) for _ in range(n)]
        k = rng.randint(1, n)
        a = find_kth_largest_heap(nums, k)
        b = find_kth_largest_quickselect(nums, k)
        c = find_kth_largest_sort(nums, k)
        d = find_kth_largest_nlargest(nums, k)
        assert a == b == c == d, f"DISAGREE nums={nums} k={k}: heap={a} qs={b} sort={c} nl={d}"

    # Adversarial: all duplicates, sorted, reverse-sorted
    for n in (50, 100):
        # The fixtures depend only on n, so build them once per size.
        fixtures = ([5] * n, list(range(n)), list(range(n, 0, -1)))
        for k in (1, n // 2, n):
            for arr in fixtures:
                a = find_kth_largest_heap(arr, k)
                b = find_kth_largest_quickselect(arr, k)
                c = find_kth_largest_sort(arr, k)
                # The library shortcut is checked here too, as in the random phase.
                d = find_kth_largest_nlargest(arr, k)
                assert a == b == c == d, f"adversarial fail arr={arr[:5]}... k={k}"

    # Streaming class — random sequence
    rng2 = random.Random(99)
    for _ in range(100):
        k = rng2.randint(1, 5)
        init = [rng2.randint(-100, 100) for _ in range(rng2.randint(0, 10))]
        kth = KthLargest(k, init)
        seen = list(init)
        for _ in range(20):
            v = rng2.randint(-100, 100)
            got = kth.add(v)
            seen.append(v)
            # With fewer than k values there is no k-th largest to compare against.
            if len(seen) >= k:
                expected = sorted(seen)[-k]
                assert got == expected, f"streaming fail seen={seen} k={k}: got {got}, want {expected}"

    print("stress_test: 1000 random arrays + adversarial + 100 streaming sequences — all agree.")


# Script entry point: run the fixed checks first, then the randomized ones.
if __name__ == "__main__":
    for run_checks in (edge_cases, stress_test):
        run_checks()
    print("ALL TESTS PASSED")
