"""
p83 — Task Scheduler (LC 621)

INVARIANT (frame model): Place the most-frequent task as anchors spaced n+1 slots apart.
  The skeleton spans (f_max - 1) * (n + 1) + c slots: f_max - 1 full frames of width n+1,
  followed by a final row holding the c tasks whose count ties f_max.
  If the other tasks more than fill every idle slot, the answer is just len(tasks).

CLOSED FORM: answer = max(len(tasks), (f_max - 1) * (n + 1) + c)
"""
from __future__ import annotations

import heapq
import random
from collections import Counter, deque
from itertools import permutations
from typing import List


def least_interval_closed(tasks: List[str], n: int) -> int:
    """Return the minimum schedule length using the frame-model closed form.

    Runs in O(N + K) for N tasks over K distinct labels.

    Args:
        tasks: Task labels; two runs of the same label need ``n`` slots between them.
        n: Cooldown length.

    Returns:
        ``max(len(tasks), (f_max - 1) * (n + 1) + c)``, where ``f_max`` is the
        highest task count and ``c`` is how many labels reach it; 0 for no tasks.
    """
    if not tasks:
        return 0
    freq = Counter(tasks)
    top = max(freq.values())
    n_top = list(freq.values()).count(top)
    # f_max - 1 full frames of width n + 1, then a final row of the tied tasks.
    frame_bound = (top - 1) * (n + 1) + n_top
    total = len(tasks)
    # When filler tasks overflow every idle slot, no idle is ever needed.
    return total if total >= frame_bound else frame_bound


def least_interval_heap(tasks: List[str], n: int) -> int:
    """O(N log K) max-heap simulation of the scheduler.

    Each round pops up to ``n + 1`` distinct tasks with the highest remaining
    counts, runs each once, and pushes back those still pending. A round costs
    the full ``n + 1`` slots (gaps are idles) unless it is the final round, in
    which case only the tasks actually run are counted.

    Args:
        tasks: Task labels; two runs of the same label need ``n`` slots between them.
        n: Cooldown length. Negative values are treated as 0 (no cooldown),
            which matches what ``least_interval_closed`` returns for them.

    Returns:
        The minimum number of time slots, or 0 for an empty task list.
    """
    if not tasks:
        return 0
    # Clamp: with n < 0 the round width n + 1 would be <= 0, nothing would
    # ever be popped, and the outer loop would never terminate.
    width = max(n, 0) + 1
    # heapq is a min-heap, so store negated counts to pop the largest first.
    heap = [-v for v in Counter(tasks).values()]
    heapq.heapify(heap)
    time = 0
    while heap:
        pending: List[int] = []
        ran = 0
        while ran < width and heap:
            left = -heapq.heappop(heap) - 1
            if left > 0:
                pending.append(-left)
            ran += 1
        for neg_cnt in pending:
            heapq.heappush(heap, neg_cnt)
        # Tasks remain: this round occupies all `width` slots (idles fill gaps).
        # Heap empty: this was the last round, so no trailing idles are counted.
        time += width if heap else ran
    return time


def least_interval_brute(tasks: List[str], n: int) -> int:
    """Brute-force oracle: minimum schedule length over every distinct task order.

    For each distinct permutation, every task is placed at the earliest slot
    its cooldown allows, and the shortest resulting schedule is returned.
    This is factorial in len(tasks), so use it only for tiny inputs (len <= 6).
    """
    best = float("inf")
    # Deduplicate orders up front; repeated labels produce identical tuples.
    for order in set(permutations(tasks)):
        # next_free[label] = earliest slot at which `label` may run again.
        next_free: dict[str, int] = {}
        clock = 0
        for label in order:
            clock = max(clock, next_free.get(label, clock))
            next_free[label] = clock + n + 1
            clock += 1
        best = min(best, clock)
    return int(best)


def edge_cases() -> None:
    """Hand-picked regression checks for the closed form and the heap simulation.

    Every case is asserted against BOTH implementations so a regression in
    either one is caught, including the tie-at-f_max case.
    """
    cases = [
        (list("AAABBB"), 2, 8),  # A B _ A B _ A B
        (list("AAABBB"), 0, 6),  # n = 0 → no cooldown
        (["A"], 5, 1),  # single task
        (list("AAAA"), 2, 10),  # A _ _ A _ _ A _ _ A
        (list("ABCDEF"), 1, 6),  # many distinct, small n
        # Ties at f_max: A B C A B C A B → (3-1)*3 + 2 = 8 = len(tasks).
        (list("AAABBBCC"), 2, 8),
        # Filler overflows the idle slots: A B C A B D A B C D E → len(tasks) = 11.
        (list("AAABBBCCDDE"), 2, 11),
        ([], 3, 0),  # empty input
    ]
    for tasks, n, expected in cases:
        assert least_interval_closed(tasks, n) == expected, (tasks, n)
        assert least_interval_heap(tasks, n) == expected, (tasks, n)


def stress_test() -> None:
    """Cross-check the closed form, heap simulation, and brute oracle.

    Uses 200 seeded random inputs of 1-6 tasks over labels A-D with
    cooldowns 0-4, small enough for the factorial brute force.
    """
    rng = random.Random(42)
    trials = 200
    for _ in range(trials):
        size = rng.randint(1, 6)
        sample = [chr(ord("A") + rng.randint(0, 3)) for _ in range(size)]
        cooldown = rng.randint(0, 4)

        # Pass copies so no solver can influence another by mutating input.
        closed, heap_ans, brute = (
            least_interval_closed(sample[:], cooldown),
            least_interval_heap(sample[:], cooldown),
            least_interval_brute(sample[:], cooldown),
        )
        assert closed == heap_ans == brute, (sample, cooldown, closed, heap_ans, brute)


if __name__ == "__main__":
    # Run the fixed regression cases first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
