"""
p04 — Merge Sorted Array
========================
LeetCode 88 · Easy · Topics: Array, Two Pointers
"""

import random
from typing import List


# ----------------------------------------------------------------------
# 1. BRUTE FORCE — copy nums2 into the trailing slots, then sort. O((M+N) log (M+N)).
# ----------------------------------------------------------------------
def merge_brute(nums1: List[int], m: int, nums2: List[int], n: int) -> None:
    """Merge sorted ``nums2[:n]`` into sorted ``nums1[:m]`` in place by sorting.

    Reference (oracle) implementation: O((m+n) log (m+n)) time and O(m+n)
    extra space for the temporary merged list.

    Args:
        nums1: Buffer whose first ``m`` entries are valid and sorted; the next
            ``n`` slots are placeholders that get overwritten.
        m: Number of valid elements at the front of ``nums1``.
        nums2: Sorted list whose first ``n`` entries are merged in.
        n: Number of elements to take from ``nums2``.

    Reads only ``nums1[:m]`` and ``nums2[:n]`` and writes only
    ``nums1[:m + n]`` -- the same contract as ``merge`` -- so the two
    implementations stay comparable even when a buffer is longer than needed.
    """
    # Assign over exactly the first m + n slots: extra trailing capacity in
    # nums1 is left untouched and nums2 entries beyond index n - 1 are ignored.
    nums1[:m + n] = sorted(nums1[:m] + nums2[:n])


# ----------------------------------------------------------------------
# 2. OPTIMAL — reverse merge in place. O(M+N) time, O(1) extra space.
# ----------------------------------------------------------------------
def merge(nums1: List[int], m: int, nums2: List[int], n: int) -> None:
    """Merge sorted ``nums2[:n]`` into sorted ``nums1[:m]`` in place.

    Fills ``nums1`` from the back (index ``m + n - 1`` downward), each time
    placing the larger of the two largest unplaced values. O(m+n) time,
    O(1) extra space.

    Args:
        nums1: Buffer whose first ``m`` entries are valid and sorted; the
            following ``n`` slots are placeholders to be overwritten.
        m: Number of valid elements at the front of ``nums1``.
        nums2: Sorted list whose first ``n`` entries are merged in.
        n: Number of elements to take from ``nums2``.
    """
    # top1 / top2 index the largest not-yet-placed value of each input, and
    # slot is the next position to fill. slot == top1 + top2 + 1 always holds,
    # so the slot being written never still holds an unread nums1 value.
    top1, top2, slot = m - 1, n - 1, m + n - 1

    # Only nums2 needs draining. Once it is empty, whatever remains in
    # nums1[0..top1] is already sorted and already in its final position.
    while top2 >= 0:
        take_from_nums1 = top1 >= 0 and nums1[top1] > nums2[top2]
        if take_from_nums1:
            nums1[slot] = nums1[top1]
            top1 -= 1
        else:
            # Ties go to nums2 (strict `>` above); either choice is correct.
            nums1[slot] = nums2[top2]
            top2 -= 1
        slot -= 1


# ----------------------------------------------------------------------
# 3. STRESS TEST
# ----------------------------------------------------------------------
def stress_test(iterations: int = 1000, max_size: int = 30) -> None:
    """Randomized cross-check of ``merge`` against ``merge_brute``.

    Each iteration draws two sorted lists of random length in [0, max_size]
    with values in [-50, 50]. It runs both implementations on independent
    buffers and checks that they agree with each other AND with
    ``sorted(a + b)``. The independent check means a bug shared by both
    implementations is still caught.

    Args:
        iterations: Number of random cases to run.
        max_size: Inclusive upper bound on each input's length.

    Raises:
        AssertionError: On the first case where a result is wrong.
    """
    rng = random.Random(42)  # fixed seed so any failure is reproducible
    # Out-of-range filler for the placeholder tail of nums1 (values are drawn
    # from [-50, 50]). If an implementation reads a placeholder slot as data,
    # the sentinel leaks into its output and the ground-truth check fails.
    placeholder = 10**6
    for _ in range(iterations):
        m = rng.randint(0, max_size)
        n = rng.randint(0, max_size)
        a = sorted(rng.randint(-50, 50) for _ in range(m))
        b = sorted(rng.randint(-50, 50) for _ in range(n))
        expected = sorted(a + b)

        # Two parallel nums1 buffers (size m+n) — one for brute, one for optimal.
        nums1_brute = a + [placeholder] * n
        nums1_opt = a + [placeholder] * n

        merge_brute(nums1_brute, m, list(b), n)
        merge(nums1_opt, m, list(b), n)

        assert nums1_brute == nums1_opt, (
            f"disagree on m={m},n={n},a={a},b={b}: brute={nums1_brute}, opt={nums1_opt}"
        )
        assert nums1_opt == expected, (
            f"wrong result on m={m},n={n},a={a},b={b}: got={nums1_opt}, expected={expected}"
        )
    print(f"stress_test PASSED — {iterations} iterations")


# ----------------------------------------------------------------------
# 4. EDGE CASES
# ----------------------------------------------------------------------
def edge_cases() -> None:
    """Run ``merge`` on hand-picked boundary inputs and report each result.

    Prints one PASS/FAIL line per case. After all cases have run, it raises
    if any of them failed, so a regression fails the run instead of showing
    up only as a line of console output.

    Raises:
        AssertionError: If one or more cases produced the wrong result.
    """
    cases = [
        (([1, 2, 3, 0, 0, 0], 3, [2, 5, 6], 3), [1, 2, 2, 3, 5, 6], "canonical"),
        (([1], 1, [], 0), [1], "n=0: nums2 empty"),
        (([0], 0, [1], 1), [1], "m=0: nums1 has no valid data"),
        (([4, 5, 6, 0, 0, 0], 3, [1, 2, 3], 3), [1, 2, 3, 4, 5, 6], "all-of-nums2-smaller"),
        (([1, 2, 3, 0, 0, 0], 3, [4, 5, 6], 3), [1, 2, 3, 4, 5, 6], "all-of-nums2-larger"),
        (([2, 2, 0, 0], 2, [2, 2], 2), [2, 2, 2, 2], "all duplicates"),
        (([0, 0, 0], 0, [1, 2, 3], 3), [1, 2, 3], "m=0 with nums2 fills"),
        (([], 0, [], 0), [], "both empty"),
        (([-3, 0, 0], 1, [-5, -1], 2), [-5, -3, -1], "negative values"),
    ]
    failed = []
    for (nums1, m, nums2, n), expected, label in cases:
        # Work on copies so the case's input lists are never mutated.
        nums1_copy = list(nums1)
        merge(nums1_copy, m, list(nums2), n)
        ok = nums1_copy == expected
        status = "PASS" if ok else "FAIL"
        print(f"  [{status}] {label}: → {nums1_copy} (expected {expected})")
        if not ok:
            failed.append(label)
    if failed:
        raise AssertionError(f"edge cases failed: {failed}")


if __name__ == "__main__":
    # Hand-written edge cases first, then the randomized stress test; each
    # phase is introduced by its own banner line.
    for banner, run_phase in (
        ("=== Edge cases ===", edge_cases),
        ("\n=== Stress test ===", stress_test),
    ):
        print(banner)
        run_phase()
