"""
p42 — Top K Frequent Elements (LeetCode 347, MEDIUM).

Implementations:
    top_k_frequent_bucket  — O(n) bucket sort by count. Optimal.
    top_k_frequent_heap    — O(n log K) min-heap of size K.
    top_k_frequent_sort    — O(n log n) baseline.
    top_k_frequent_words   — LC 692 with lex tie-break.

Stress test: 500 random arrays; each result is validated as a SET against the count
cutoff (LC accepts any order, and any tie-break among values at the cutoff count).

Run: python3 solution.py
"""
from __future__ import annotations
import heapq
import random
from collections import Counter
from typing import List


# --------------------------------------------------------------------------------------
# Bucket sort. O(n). Optimal.
# INVARIANT: buckets[c] = list of elements with frequency exactly c.
# Counts are bounded by len(nums), so a size-(n+1) array suffices.
# --------------------------------------------------------------------------------------
def top_k_frequent_bucket(nums: List[int], k: int) -> List[int]:
    """Return up to ``k`` most frequent values of ``nums`` using bucket sort.

    Values are grouped by occurrence count and the groups are scanned from the
    highest count downward, so the result is ordered by non-increasing
    frequency. Values sharing a count keep the order in which ``Counter``
    yields them.

    Args:
        nums: Input values; may be empty.
        k: Maximum number of values to return.

    Returns:
        A list of at most ``k`` distinct values. Empty when ``k <= 0`` or when
        ``nums`` is empty; shorter than ``k`` when ``nums`` has fewer than
        ``k`` distinct values.
    """
    if k <= 0:
        # Without this guard the early-exit test below can never fire (the
        # result already has length 1 at the first check), so every distinct
        # value would be returned instead of none.
        return []

    counter = Counter(nums)
    # buckets[c] holds the values occurring exactly c times. A count can never
    # exceed len(nums), so len(nums) + 1 slots cover every possible count.
    buckets: List[List[int]] = [[] for _ in range(len(nums) + 1)]
    for val, cnt in counter.items():
        buckets[cnt].append(val)

    result: List[int] = []
    # Highest count first; buckets[0] is always empty, so visiting it is harmless.
    for bucket in reversed(buckets):
        for val in bucket:
            result.append(val)
            if len(result) == k:
                return result
    # Fewer than k distinct values exist: return all of them.
    return result


# --------------------------------------------------------------------------------------
# Min-heap of size K with (count, val) tuples. O(n log K). Streaming-friendly.
# --------------------------------------------------------------------------------------
def top_k_frequent_heap(nums: List[int], k: int) -> List[int]:
    """Return up to ``k`` most frequent values of ``nums`` using a size-``k`` min-heap.

    The heap stores ``(count, value)`` tuples, so its root is always the least
    frequent entry currently kept. A new entry replaces the root only when its
    count is strictly greater.

    Args:
        nums: Input values; may be empty.
        k: Maximum number of values to return.

    Returns:
        A list of at most ``k`` distinct values in heap-array order (not sorted
        by frequency). Empty when ``k <= 0`` or when ``nums`` is empty.
    """
    if k <= 0:
        # With k <= 0 the heap never receives an entry, so the heap[0] peek in
        # the loop below would raise IndexError on the first distinct value.
        return []

    counter = Counter(nums)
    heap: List[tuple] = []
    for val, cnt in counter.items():
        if len(heap) < k:
            # Still filling the heap up to k entries.
            heapq.heappush(heap, (cnt, val))
        elif cnt > heap[0][0]:
            # Strictly more frequent than the weakest kept entry: evict the root.
            heapq.heappushpop(heap, (cnt, val))
    return [val for _, val in heap]


# --------------------------------------------------------------------------------------
# Sort baseline. O(n log n).
# --------------------------------------------------------------------------------------
def top_k_frequent_sort(nums: List[int], k: int) -> List[int]:
    """Return the ``k`` most frequent values of ``nums`` via ``Counter.most_common``.

    Baseline implementation: the counting and ranking are delegated entirely to
    the standard library, and only the values (not their counts) are returned,
    ordered from most to least frequent.
    """
    ranked = Counter(nums).most_common(k)
    return [value for value, _count in ranked]


# --------------------------------------------------------------------------------------
# LC 692: Top K Frequent Words with lex tie-break.
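# Required ordering: descending by count, ascending (lexicographic) by word on a tie.
# Sorting by the composite key (-cnt, word) produces exactly that order:
#   - a higher count gives a smaller -cnt, so it sorts first;
#   - on a count tie the -cnt parts are equal and the lex-SMALLER word sorts first.
# The answer is then the first K entries of the sorted list.
# NOTE: a size-K min-heap of plain (-cnt, word) tuples would NOT work: its root is the
# entry with the HIGHEST count, so evicting the root on overflow discards the best
# candidate. A size-K heap needs (cnt, word-in-reverse-lex-order) keys instead.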
# --------------------------------------------------------------------------------------
def top_k_frequent_words(words: List[str], k: int) -> List[str]:
    """Return up to ``k`` most frequent words, breaking count ties lexicographically.

    Args:
        words: Input words; may be empty.
        k: Maximum number of words to return.

    Returns:
        Words ordered by descending count, then ascending lexicographic order
        among words with the same count. Empty when ``k <= 0`` or when
        ``words`` is empty; shorter than ``k`` when there are fewer than ``k``
        distinct words.
    """
    if k <= 0:
        # A negative k would otherwise slice from the end of the ranking
        # (e.g. [:-1] drops only the last word) instead of returning nothing.
        return []

    counter = Counter(words)
    # Composite key (-count, word): highest count first, lex-ascending on a tie.
    # A full sort is used for readability; the number of distinct words can be
    # O(n) anyway.
    ranked = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    return [word for word, _ in ranked[:k]]


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Run hand-picked edge cases and print ``edge_cases: PASS`` on success.

    Every integer case is checked against all three LC 347 implementations
    (bucket, heap, sort). Results are compared as sets because any order is
    accepted. The LC 692 cases compare exact lists because their order is
    fully specified.

    Raises:
        AssertionError: If any implementation returns an invalid answer.
    """
    int_impls = (top_k_frequent_bucket, top_k_frequent_heap, top_k_frequent_sort)

    # (nums, k, expected set). Each answer is unique, so set equality is exact.
    exact_cases = [
        ([1, 1, 1, 2, 2, 3], 2, {1, 2}),          # LC canonical
        ([1], 1, {1}),                            # single element
        ([4, 4, 4, 4, 4], 1, {4}),                # all same
        ([1, 2, 3, 1, 2, 3], 3, {1, 2, 3}),       # k = #distinct
        ([-1, -1, -2, -2, -2, 3], 2, {-1, -2}),   # negatives
    ]

    for impl in int_impls:
        for nums, k, expected in exact_cases:
            got = impl(nums, k)
            # The length check rules out duplicates hiding behind the set comparison.
            assert len(got) == k and set(got) == expected, (impl.__name__, nums, k, got)

        # All distinct (every count = 1): any 3 of the 5 values is a valid answer.
        res = impl([1, 2, 3, 4, 5], 3)
        assert len(res) == 3 and len(set(res)) == 3, (impl.__name__, res)
        assert set(res).issubset({1, 2, 3, 4, 5}), (impl.__name__, res)

    # LC 692 tie-break: "i" and "love" both occur twice; "i" sorts first.
    assert top_k_frequent_words(["i", "love", "leetcode", "i", "love", "coding"], 2) == ["i", "love"]
    # Counts: "the"=4, "is"=3, "sunny"=3, "day"=2.
    # "is" and "sunny" tie on count, and "is" precedes "sunny" lexicographically.
    w = ["the", "day", "is", "sunny", "the", "the", "the", "sunny", "is", "is", "sunny", "day"]
    assert top_k_frequent_words(w, 4) == ["the", "is", "sunny", "day"]

    print("edge_cases: PASS")


def stress_test() -> None:
    """Validate the bucket, heap and sort variants on 500 seeded random arrays.

    The three results are not compared with each other, because ties at the
    k-th position allow several valid answers. Instead each result is checked
    independently against the k-th largest count (the cutoff):

    * it has exactly k distinct values,
    * every value it contains has count >= cutoff,
    * every value with count > cutoff is present.

    Raises:
        AssertionError: If any variant returns an invalid top-k set.
    """
    rng = random.Random(42)
    for _ in range(500):
        n = rng.randint(1, 50)
        nums = [rng.randint(0, 10) for _ in range(n)]
        # k is always between 1 and the number of distinct values.
        k = rng.randint(1, len(set(nums)))

        candidates = [
            set(top_k_frequent_bucket(nums, k)),
            set(top_k_frequent_heap(nums, k)),
            set(top_k_frequent_sort(nums, k)),
        ]

        counter = Counter(nums)
        # Count of the k-th most frequent value; ties at this count are free.
        cutoff = sorted(counter.values(), reverse=True)[k - 1]
        # Values strictly above the cutoff belong to every valid answer.
        must_have = {v for v, freq in counter.items() if freq > cutoff}

        for sset in candidates:
            assert len(sset) == k
            for val in sset:
                assert counter[val] >= cutoff, f"val {val} count {counter[val]} < cutoff {cutoff}"
            assert must_have.issubset(sset), f"missing must-haves {must_have - sset} in {sset}"

    print("stress_test: 500 random arrays — bucket, heap, sort all return valid top-k sets.")


if __name__ == "__main__":
    # Script entry point: deterministic edge cases first, then the randomized
    # stress test. Either one raises AssertionError on failure, so the final
    # line is printed only when everything passed.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
