"""
p48 — Non-overlapping Intervals (LeetCode 435, MEDIUM).

Return the MINIMUM number of intervals to remove so that the remaining ones are pairwise non-overlapping (half-open semantics).

Implementations:
    erase_overlap_intervals                     — greedy sort-by-end. O(n log n). Canonical.
    erase_overlap_intervals_sort_by_start_WRONG — educational: shows the bug.
    erase_overlap_intervals_dp                  — sort-by-end + DP with bisect. O(n log n).
    erase_overlap_intervals_brute               — try all 2^n subsets. Tiny n only.

Half-open: [1, 2] and [2, 3] do NOT overlap (LC 435 convention).
Every interval is assumed to satisfy start < end, as LC 435's constraints guarantee.
"""
from __future__ import annotations
import bisect
import random
from itertools import combinations
from typing import List


# --------------------------------------------------------------------------------------
# Canonical: greedy by END. O(n log n).
# INVARIANT: prev_end = end of latest kept interval. Keep current iff start >= prev_end.
# --------------------------------------------------------------------------------------
def erase_overlap_intervals(intervals: List[List[int]]) -> int:
    """Return the minimum number of intervals to drop so the rest don't overlap.

    Greedy: scan the intervals in order of increasing end and keep each one
    whose start is at or after the end of the last kept interval (touching
    endpoints are allowed -- half-open semantics). Every other interval is
    dropped. O(n log n), dominated by the sort.
    """
    if not intervals:
        return 0
    ordered = sorted(intervals, key=lambda iv: iv[1])
    kept = 0
    frontier = float("-inf")  # end of the most recently kept interval
    for start, end in ordered:
        if start >= frontier:
            kept += 1
            frontier = end
    return len(ordered) - kept


# --------------------------------------------------------------------------------------
# WRONG: sort by start. Kept here as the educational counterexample.
# Fails on [[1, 100], [2, 3], [4, 5]]: returns 2 instead of 1.
# --------------------------------------------------------------------------------------
def erase_overlap_intervals_sort_by_start_WRONG(intervals: List[List[int]]) -> int:
    """Deliberately incorrect variant: the same greedy scan, but ordered by START.

    A long early interval is kept first and then forces every interval nested
    inside it to be dropped. For example, [[1, 100], [2, 3], [4, 5]] yields 2,
    while the optimum is 1. Kept only to document the bug.
    """
    if not intervals:
        return 0
    by_start = sorted(intervals, key=lambda iv: iv[0])
    kept = 0
    last_end = float("-inf")  # end of the most recently kept interval
    for start, end in by_start:
        if start >= last_end:
            kept += 1
            last_end = end
    return len(by_start) - kept


# --------------------------------------------------------------------------------------
# DP variant: sort by end into arr; dp[i] = max number of pairwise non-overlapping intervals among arr[0..i].
# Uses bisect to find latest compatible predecessor. O(n log n).
# Answer = n - max(dp).
# --------------------------------------------------------------------------------------
def erase_overlap_intervals_dp(intervals: List[List[int]]) -> int:
    """Return the minimum number of removals, computed by DP over end-sorted intervals.

    After sorting by end into ``arr``, dp[i] is the largest number of pairwise
    non-overlapping intervals that can be chosen from arr[0..i]. Interval i is
    either skipped (dp[i-1]) or kept on top of the best solution ending at its
    latest compatible predecessor j. That predecessor is the rightmost j < i
    with ends[j] <= start_i, giving dp[j] + 1. Answer = n - dp[n-1].
    O(n log n).

    Args:
        intervals: [start, end] pairs. The list itself is not reordered; a
            sorted copy is used.

    Returns:
        Minimum number of intervals to remove. For non-empty input this is in
        [0, n - 1], since at least one interval can always be kept.
    """
    n = len(intervals)
    if n == 0:
        return 0
    arr = sorted(intervals, key=lambda x: x[1])
    ends = [a[1] for a in arr]
    dp = [1] * n
    for i in range(n):
        start = arr[i][0]
        # Rightmost j < i with ends[j] <= start (half-open: end == start is ok).
        # hi=i keeps j strictly before i. For a zero-length interval
        # (start == end), an unbounded search would land on i itself or on a
        # later, not-yet-computed entry. dp could then exceed n and the answer
        # would go negative (e.g. [[2, 2]] -> -1).
        j = bisect.bisect_right(ends, start, hi=i) - 1
        if j >= 0:
            dp[i] = max(dp[i], dp[j] + 1)
        if i > 0:
            dp[i] = max(dp[i], dp[i - 1])
    return n - dp[-1]


# --------------------------------------------------------------------------------------
# Brute oracle: try every subset; check disjoint; track max kept. O(2^n * n^2).
# --------------------------------------------------------------------------------------
def erase_overlap_intervals_brute(intervals: List[List[int]]) -> int:
    """Exhaustive oracle: the largest pairwise-disjoint subset decides the answer.

    Every subset of every size is tested (the empty subset always qualifies),
    and the answer is n minus the size of the largest disjoint one.
    O(2^n * n log n), so only suitable for tiny inputs.
    """
    n = len(intervals)
    if n == 0:
        return 0

    def _pairwise_disjoint(chosen):
        # After sorting by start, it suffices to compare neighbours; touching
        # endpoints are allowed (half-open semantics).
        by_start = sorted(chosen, key=lambda iv: iv[0])
        return not any(
            right[0] < left[1] for left, right in zip(by_start, by_start[1:])
        )

    largest = max(
        size
        for size in range(n + 1)
        for chosen in combinations(intervals, size)
        if _pairwise_disjoint(chosen)
    )
    return n - largest


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Run hand-picked regression cases; raise AssertionError on the first mismatch."""
    greedy_cases = [
        ([[1, 2], [2, 3], [3, 4], [1, 3]], 1),
        ([[1, 2], [1, 2], [1, 2]], 2),
        ([[1, 2], [2, 3]], 0),   # half-open touch — no overlap
        ([], 0),
        ([[1, 100]], 0),
    ]
    for given, expected in greedy_cases:
        assert erase_overlap_intervals(given) == expected

    # The classic counterexample: sort-by-start gives 2, sort-by-end gives 1.
    cex = [[1, 100], [2, 3], [4, 5]]
    assert erase_overlap_intervals(cex) == 1
    assert erase_overlap_intervals_sort_by_start_WRONG(cex) == 2   # documents the bug

    # The DP must reach the same answers the greedy is known to give.
    dp_cases = (
        ([[1, 2], [2, 3], [3, 4], [1, 3]], 1),
        (cex, 1),
        ([], 0),
    )
    for given, expected in dp_cases:
        assert erase_overlap_intervals_dp(given) == expected

    print("edge_cases: PASS")


def stress_test(trials: int = 500, seed: int = 42) -> None:
    """Cross-check the greedy, DP, and brute-force solvers on random small inputs.

    Args:
        trials: number of random instances to generate (default 500).
        seed: seed for the private ``random.Random`` instance, so runs are
            reproducible (default 42).

    Raises:
        AssertionError: on the first instance where the three solvers disagree;
            the message includes the offending input.
    """
    rng = random.Random(seed)
    for trial in range(trials):
        n = rng.randint(0, 8)   # brute is O(2^n) — keep small
        intervals = []
        for _ in range(n):
            start = rng.randint(0, 10)
            end = rng.randint(start + 1, start + 5)   # always start < end
            intervals.append([start, end])

        # Each solver receives its own copy of the input, so no solver can
        # observe mutations made by another.
        g = erase_overlap_intervals([list(x) for x in intervals])
        d = erase_overlap_intervals_dp([list(x) for x in intervals])
        brute = erase_overlap_intervals_brute([list(x) for x in intervals])

        assert g == d == brute, f"trial {trial}: greedy={g} dp={d} brute={brute} input={intervals}"

    print(f"stress_test: {trials} random trials — greedy, DP, brute all agree.")


if __name__ == "__main__":
    # Fixed regression cases first, then the randomized cross-check.
    for run_suite in (edge_cases, stress_test):
        run_suite()
    print("ALL TESTS PASSED")
