"""
p47 — Insert Interval (LeetCode 57, MEDIUM).

Three implementations:
    insert_interval         — three-phase O(n) scan. Canonical.
    insert_interval_lazy    — append + sort + merge (oracle). O(n log n).
    insert_interval_binary  — bisect boundaries + merge slice. O(log n) search, O(n) overall.

Closed-interval semantics: [1, 2] and [2, 3] merge.

Stress: 1000 random pre-sorted-disjoint inputs + random new; all three agree.
"""
from __future__ import annotations
import bisect
import random
from typing import List


# --------------------------------------------------------------------------------------
# Canonical: three-phase scan. O(n).
# INVARIANT: intervals are sorted by start AND non-overlapping.
# Phases: (1) strictly before new, (2) merge with new, (3) strictly after new.
# --------------------------------------------------------------------------------------
def insert_interval(intervals: List[List[int]], new: List[int]) -> List[List[int]]:
    """Insert ``new`` into sorted, disjoint ``intervals``, merging any overlap.

    Three-phase linear scan (closed-interval semantics, so touching endpoints merge):
      1. copy every interval that ends strictly before ``new`` starts;
      2. absorb every interval that starts at or before the (growing) merged end;
      3. copy everything that remains.

    Args:
        intervals: ``[start, end]`` pairs, sorted by start and non-overlapping.
            Not modified.
        new: the ``[start, end]`` pair to insert. Not modified.

    Returns:
        A new list of freshly allocated ``[start, end]`` lists. No inner list is
        shared with ``intervals`` or ``new``, so callers may mutate the result
        without affecting their inputs.
    """
    merged = list(new)  # private copy: phase 2 widens it in place
    out: List[List[int]] = []
    i, n = 0, len(intervals)

    # Phase 1: strictly before new (end < new.start).
    while i < n and intervals[i][1] < merged[0]:
        out.append(list(intervals[i]))  # copy: never hand back the caller's list
        i += 1

    # Phase 2: overlapping (start <= merged end) — widen the merged interval.
    while i < n and intervals[i][0] <= merged[1]:
        merged[0] = min(merged[0], intervals[i][0])
        merged[1] = max(merged[1], intervals[i][1])
        i += 1
    out.append(merged)

    # Phase 3: strictly after — copied for the same reason as phase 1.
    out.extend(list(intervals[j]) for j in range(i, n))
    return out


# --------------------------------------------------------------------------------------
# Lazy oracle: append + sort + merge (reuse p46 logic).
# --------------------------------------------------------------------------------------
def insert_interval_lazy(intervals: List[List[int]], new: List[int]) -> List[List[int]]:
    """Reference oracle: pool all intervals with ``new``, sort by start, sweep-merge.

    O(n log n). Works on copies, so neither argument is mutated and the result
    shares no inner lists with the inputs. Closed semantics: an interval whose
    start equals the running end is merged into it.
    """
    pool = [list(iv) for iv in intervals]
    pool.append(list(new))
    pool.sort(key=lambda iv: iv[0])  # stable: equal starts keep input order

    merged: List[List[int]] = [pool[0]]
    for lo, hi in pool[1:]:
        tail = merged[-1]
        if lo > tail[1]:
            merged.append([lo, hi])
        else:
            tail[1] = max(tail[1], hi)
    return merged


# --------------------------------------------------------------------------------------
# Binary-search boundaries + merge the affected slice. The bisects are O(log n), but
# pre-extracting the ends/starts arrays and copying the output are O(n): O(n) overall.
# --------------------------------------------------------------------------------------
def insert_interval_binary(intervals: List[List[int]], new: List[int]) -> List[List[int]]:
    """Insert ``new`` by bisecting for the window of intervals it touches.

    Because ``intervals`` is sorted and disjoint, both its start column and its end
    column are non-decreasing, so each window edge is a single bisect:
      * ``first`` — index of the first interval with ``end >= new.start``;
      * ``stop``  — index of the first interval with ``start > new.end``.
    Everything in ``[first, stop)`` overlaps ``new`` (closed semantics) and collapses
    into one interval. An empty window means ``new`` is inserted unchanged.

    The searches are O(log n), but building the key columns and copying the output
    are O(n), so the function is O(n) overall. Inputs are not mutated, and every
    interval in the result is a fresh list.
    """
    if not intervals:
        return [list(new)]
    new_start, new_end = new
    first = bisect.bisect_left([iv[1] for iv in intervals], new_start)
    stop = bisect.bisect_right([iv[0] for iv in intervals], new_end)

    if first == stop:
        # Nothing overlaps: new slots in between intervals[first-1] and intervals[first].
        middle = [list(new)]
    else:
        middle = [[min(new_start, intervals[first][0]), max(new_end, intervals[stop - 1][1])]]

    head = [list(iv) for iv in intervals[:first]]
    tail = [list(iv) for iv in intervals[stop:]]
    return head + middle + tail


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Run hand-picked boundary cases against every implementation.

    Each case is ``(intervals, new, expected)``. All three implementations must
    return ``expected`` and must leave both of their arguments unmodified.

    Raises:
        AssertionError: naming the implementation and case on the first mismatch.
    """
    cases = [
        ([[1, 3], [6, 9]], [2, 5], [[1, 5], [6, 9]]),
        ([[1, 2], [3, 5], [6, 7], [8, 10], [12, 16]], [4, 8], [[1, 2], [3, 10], [12, 16]]),
        ([], [5, 7], [[5, 7]]),                                     # empty input
        ([[1, 5]], [2, 3], [[1, 5]]),                               # contained
        ([[1, 5]], [6, 8], [[1, 5], [6, 8]]),                       # after
        ([[5, 8]], [1, 2], [[1, 2], [5, 8]]),                       # before
        ([[1, 5]], [5, 7], [[1, 7]]),                               # touching merge (right)
        ([[5, 7]], [1, 5], [[1, 7]]),                               # touching merge (left)
        ([[3, 5], [12, 15]], [6, 6], [[3, 5], [6, 6], [12, 15]]),   # point in a gap
        ([[1, 5]], [0, 10], [[0, 10]]),                             # new contains all
        ([[1, 2], [4, 5], [7, 8]], [2, 7], [[1, 8]]),               # spans all, touching both ends
    ]
    impls = (insert_interval, insert_interval_lazy, insert_interval_binary)
    for intervals, new, expected in cases:
        for impl in impls:
            # Fresh copies per call so one implementation cannot affect another.
            iv_arg = [list(x) for x in intervals]
            new_arg = list(new)
            got = impl(iv_arg, new_arg)
            assert got == expected, (
                f"{impl.__name__}({intervals}, {new}) = {got}, expected {expected}"
            )
            assert iv_arg == intervals and new_arg == new, (
                f"{impl.__name__}({intervals}, {new}) mutated its arguments"
            )

    print("edge_cases: PASS")


def stress_test(trials: int = 1000, seed: int = 42) -> None:
    """Cross-check all three implementations on random inputs.

    Each trial builds a sorted, disjoint interval list plus a random ``new``
    interval. It asserts that all three implementations agree and that none of
    them mutates its arguments. Consecutive intervals are separated by a gap of
    at least 1, so they never share an endpoint.

    Args:
        trials: number of random cases; the default 1000 matches the original run.
        seed: RNG seed, so any failing trial is reproducible.
    """
    rng = random.Random(seed)
    impls = (insert_interval, insert_interval_lazy, insert_interval_binary)
    for trial in range(trials):
        # Build a random sorted-disjoint interval list.
        n = rng.randint(0, 15)
        intervals: List[List[int]] = []
        cur = 0
        for _ in range(n):
            cur += rng.randint(1, 4)    # gap >= 1 keeps neighbours disjoint
            length = rng.randint(0, 4)  # length 0 produces point intervals
            intervals.append([cur, cur + length])
            cur += length

        # Random new interval anywhere in the range.
        a = rng.randint(0, 60)
        b = rng.randint(a, a + 10)
        new = [a, b]

        results = []
        for impl in impls:
            iv_arg = [list(x) for x in intervals]
            new_arg = list(new)
            results.append(impl(iv_arg, new_arg))
            assert iv_arg == intervals and new_arg == new, (
                f"trial {trial}: {impl.__name__} mutated its arguments"
            )

        r1, r2, r3 = results
        assert r1 == r2 == r3, f"trial {trial}: r1={r1} r2={r2} r3={r3} input={intervals} new={new}"

    print(f"stress_test: {trials} random trials — three-phase, lazy, binary all agree.")


if __name__ == "__main__":
    # Run the deterministic suite first, then the randomized cross-check.
    for run_suite in (edge_cases, stress_test):
        run_suite()
    print("ALL TESTS PASSED")
