"""
p44 — Merge K Sorted Lists (LeetCode 23, HARD).

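Three implementations (N = total node count, K = number of lists):
    merge_k_lists_heap            — min-heap of the K current heads. O(N log K).
    merge_k_lists_divide_conquer  — iterative pairwise merging. O(N log K).
    merge_k_lists_brute           — collect all values, sort, rebuild. O(N log N) baseline.

Stress test: 200 seeded random trials check that all three produce identical output arrays.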
"""
from __future__ import annotations
import heapq
import random
from typing import List, Optional


class ListNode:
    """Node of a singly linked list carrying an integer payload.

    ``__slots__`` removes the per-instance ``__dict__``; the stress test
    allocates many of these nodes.
    """

    __slots__ = ("val", "next")

    def __init__(self, val: int = 0, nxt: "Optional[ListNode]" = None) -> None:
        # The parameter is spelled ``nxt`` so the builtin ``next`` is not shadowed.
        self.val, self.next = val, nxt


def to_list_from_array(arr: List[int]) -> Optional[ListNode]:
    """Build a singly linked list whose node values follow *arr* in order.

    The list is assembled back-to-front by prepending each value, so no
    sentinel node is needed. Returns the head, or ``None`` for an empty *arr*.
    """
    head: Optional[ListNode] = None
    # Materialise first so any iterable (not only sequences) can be reversed.
    for value in reversed(list(arr)):
        head = ListNode(value, head)
    return head


def to_array_from_list(head: Optional[ListNode]) -> List[int]:
    """Return the values of the linked list starting at *head*, in order.

    ``None`` (an empty list) yields ``[]``. The nodes are only read.
    """
    values: List[int] = []
    node = head
    while node is not None:
        values.append(node.val)
        node = node.next
    return values


# --------------------------------------------------------------------------------------
# Heap of K head pointers. O(N log K).
# Tie-break on the source-list index `i`: without it, equal values would make tuple
# comparison fall through to the ListNode objects, which define no ordering (TypeError).
# --------------------------------------------------------------------------------------
def merge_k_lists_heap(lists: List[Optional[ListNode]]) -> Optional[ListNode]:
    """Merge K sorted linked lists using a min-heap of the current list heads.

    Nodes are relinked in place (no value nodes are allocated), so the input
    lists are consumed. O(N log K) time and O(K) extra space, where N is the
    total number of nodes and K is ``len(lists)``.

    Args:
        lists: Heads of sorted linked lists; ``None`` entries are empty lists.

    Returns:
        Head of the merged sorted list, or ``None`` if every input is empty.
    """
    # Entries are (value, list index, node). Each list has at most one live
    # entry, so the index is unique and ties on value never reach the nodes.
    heap: List[tuple] = [
        (head.val, i, head) for i, head in enumerate(lists) if head is not None
    ]
    heapq.heapify(heap)  # O(K), versus O(K log K) for K individual pushes

    dummy = ListNode()
    tail = dummy
    while heap:
        _, i, node = heap[0]
        tail.next = node
        tail = node
        successor = node.next
        if successor is not None:
            # Pop the minimum and push its successor with a single sift.
            heapq.heapreplace(heap, (successor.val, i, successor))
        else:
            heapq.heappop(heap)
    # The heap only empties after taking a node whose list is exhausted, so
    # the final node's ``next`` is already None and the result is terminated.
    return dummy.next


# --------------------------------------------------------------------------------------
# Divide and conquer iterative pairwise merge. O(N log K).
# --------------------------------------------------------------------------------------
def merge_two_lists(a: Optional[ListNode], b: Optional[ListNode]) -> Optional[ListNode]:
    """Stably merge two sorted linked lists by relinking their nodes.

    On equal values the node from *a* is placed first. Either argument may be
    ``None``. Returns the merged head (``None`` only when both are ``None``).
    """
    sentinel = ListNode()
    cursor = sentinel
    while a is not None and b is not None:
        if a.val <= b.val:
            cursor.next, a = a, a.next
        else:
            cursor.next, b = b, b.next
        cursor = cursor.next
    # At most one input still has nodes; splice its remainder on whole.
    cursor.next = b if a is None else a
    return sentinel.next


def merge_k_lists_divide_conquer(lists: List[Optional[ListNode]]) -> Optional[ListNode]:
    """Merge K sorted lists by repeatedly merging neighbours in pairs.

    Each round halves the number of lists. An odd trailing list is merged with
    ``None``, which carries it forward unchanged. That gives ceil(log2 K)
    rounds of O(N) work: O(N log K) overall. Nodes are relinked in place. The
    caller's Python list is not mutated; only the local name is rebound.

    Returns ``None`` when *lists* is empty.
    """
    if not lists:
        return None
    while len(lists) > 1:
        count = len(lists)
        lists = [
            merge_two_lists(lists[j], lists[j + 1] if j + 1 < count else None)
            for j in range(0, count, 2)
        ]
    return lists[0]


# --------------------------------------------------------------------------------------
# Brute: collect all values, sort, rebuild. O(N log N).
# --------------------------------------------------------------------------------------
def merge_k_lists_brute(lists: List[Optional[ListNode]]) -> Optional[ListNode]:
    """Reference merge: gather every value, sort, and build a fresh list.

    The input nodes are only read, never relinked. O(N log N) time, O(N) space.
    Returns ``None`` when there are no values at all.
    """
    values: List[int] = []
    for node in lists:
        while node is not None:
            values.append(node.val)
            node = node.next
    return to_list_from_array(sorted(values))


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def _check_all_merges(arrs: List[List[int]], expected: List[int]) -> None:
    """Assert that every merge implementation turns *arrs* into *expected*.

    Each implementation gets its own freshly built linked lists, because the
    heap and divide-and-conquer merges relink the input nodes in place. An
    empty *expected* additionally requires the returned head to be ``None``.
    """
    for merge in (merge_k_lists_heap, merge_k_lists_divide_conquer, merge_k_lists_brute):
        head = merge([to_list_from_array(a) for a in arrs])
        if not expected:
            assert head is None, f"{merge.__name__}: expected None for arrs={arrs}"
        got = to_array_from_list(head)
        assert got == expected, f"{merge.__name__}: got {got} for arrs={arrs}"


def edge_cases() -> None:
    """Run hand-picked edge cases against all three merge implementations.

    Inputs are written as Python arrays. An empty array stands for a ``None``
    (empty) linked list, since ``to_list_from_array([])`` returns ``None``.
    """
    cases = [
        # LC canonical
        ([[1, 4, 5], [1, 3, 4], [2, 6]], [1, 1, 2, 3, 4, 4, 5, 6]),
        # Empty input array
        ([], []),
        # All-empty lists
        ([[], [], []], []),
        # Single list
        ([[1, 2, 3]], [1, 2, 3]),
        # Single-node lists in reverse order
        ([[3], [2], [1]], [1, 2, 3]),
        # All-equal values (tie-break stress)
        ([[5, 5, 5] for _ in range(4)], [5] * 12),
        # Mixed empty/non-empty
        ([[], [1, 2], [], [0, 3]], [0, 1, 2, 3]),
        # Negatives
        ([[-5, 0, 5], [-3, -1, 2]], [-5, -3, -1, 0, 2, 5]),
    ]
    for arrs, expected in cases:
        _check_all_merges(arrs, expected)

    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check the three merges on 200 seeded random inputs.

    Each trial draws 0-20 sorted arrays of length 0-30 with values in
    [-50, 50]. Every implementation receives its own freshly built linked
    lists, because the heap and divide-and-conquer merges relink nodes in place.
    """
    rng = random.Random(42)
    implementations = (
        merge_k_lists_heap,
        merge_k_lists_divide_conquer,
        merge_k_lists_brute,
    )
    for _ in range(200):
        # Draw order matches a nested loop: k first, then each list's n and values.
        arrs = [
            sorted(rng.randint(-50, 50) for _ in range(rng.randint(0, 30)))
            for _ in range(rng.randint(0, 20))
        ]
        a1, a2, a3 = (
            to_array_from_list(merge([to_list_from_array(a) for a in arrs]))
            for merge in implementations
        )
        assert a1 == a2 == a3, f"mismatch on arrs={arrs}"

    print("stress_test: 200 random trials — heap, D&C, brute all agree.")


if __name__ == "__main__":
    # Hand-picked cases first, then the seeded random sweep. Any failing
    # assert aborts before the final banner is printed.
    for _suite in (edge_cases, stress_test):
        _suite()
    print("ALL TESTS PASSED")
