"""
p93 — Implement strStr() / KMP (LC 28)

INVARIANT (lps): lps[i] = length of the longest proper prefix of
  needle[0..i] (inclusive) that is also a suffix of needle[0..i].

INVARIANT (match loop): i (haystack pointer) NEVER decreases. On mismatch,
  j (needle pointer) falls back to lps[j-1] if j > 0; otherwise i advances.
  Total work is O(n+m), amortized over the fallbacks.
"""
from __future__ import annotations

import random
import string
from typing import List


def _build_lps(needle: str) -> List[int]:
    """Build the KMP prefix-function ("lps") table for *needle*.

    Returns a list ``table`` with ``len(needle)`` entries, where ``table[k]``
    is the length of the longest proper prefix of ``needle[:k + 1]`` that is
    also a suffix of ``needle[:k + 1]``. An empty needle yields ``[]``.
    """
    table = [0] * len(needle)
    matched = 0  # length of the border carried over from the previous position
    for pos in range(1, len(needle)):
        ch = needle[pos]
        # Shrink the candidate border until it can be extended by ``ch``
        # or there is nothing left to shrink.
        while matched > 0 and ch != needle[matched]:
            matched = table[matched - 1]
        if ch == needle[matched]:
            matched += 1
        table[pos] = matched
    return table


def strstr_kmp(haystack: str, needle: str) -> int:
    """Return the index of the first occurrence of *needle* in *haystack*.

    Knuth-Morris-Pratt search, O(n + m): the haystack is scanned once, left
    to right, and on a mismatch only the needle position falls back, using
    the table from ``_build_lps``.

    Returns 0 for an empty needle and -1 when the needle does not occur.
    """
    if needle == "":
        return 0
    pattern_len = len(needle)
    if pattern_len > len(haystack):
        return -1

    fallback = _build_lps(needle)
    matched = 0  # number of needle characters currently matched
    for pos, ch in enumerate(haystack):
        # Fall back through shorter borders until ``ch`` extends one,
        # or no partial match remains.
        while matched > 0 and ch != needle[matched]:
            matched = fallback[matched - 1]
        if ch == needle[matched]:
            matched += 1
            if matched == pattern_len:
                # ``pos`` is the last matched index; step back to the start.
                return pos + 1 - pattern_len
    return -1


def strstr_brute(haystack: str, needle: str) -> int:
    """Reference oracle: naive O(n*m) search for *needle* in *haystack*.

    Tries every feasible start offset in increasing order and returns the
    first one at which the needle matches; 0 for an empty needle, -1 when
    there is no occurrence.
    """
    if needle == "":
        return 0
    start_count = len(haystack) - len(needle) + 1  # <= 0 when needle is longer
    hits = (
        start
        for start in range(start_count)
        if haystack.startswith(needle, start)
    )
    return next(hits, -1)


def edge_cases() -> None:
    """Run hand-picked regression checks for ``strstr_kmp`` and ``_build_lps``.

    Each case is table-driven so that a failure reports the offending input
    together with the actual and expected values.

    Raises:
        AssertionError: if any check disagrees with its expected value.
    """
    # (haystack, needle, expected index of first occurrence, or -1)
    search_cases = [
        # LeetCode 28 examples.
        ("sadbutsad", "sad", 0),
        ("leetcode", "leeto", -1),
        ("hello", "ll", 2),
        # Empty needle matches at index 0, even in an empty haystack.
        ("abc", "", 0),
        ("", "", 0),
        # A non-empty needle cannot occur in an empty haystack.
        ("", "a", -1),
        # Needle longer than haystack.
        ("a", "abc", -1),
        # Needle == haystack.
        ("abc", "abc", 0),
        # Periodic needle.
        ("aaaaab", "aaab", 2),
        # Match at very end.
        ("xyzabc", "abc", 3),
        # Match found only after repeated partial-prefix fallbacks.
        ("ababababc", "ababc", 4),
        # No occurrence despite repeated partial-prefix matches.
        ("abababab", "ababc", -1),
    ]
    for haystack, needle, expected in search_cases:
        got = strstr_kmp(haystack, needle)
        assert got == expected, (haystack, needle, got, expected)

    # lps spot checks: (needle, expected table)
    lps_cases = [
        ("aabaaab", [0, 1, 0, 1, 2, 2, 3]),
        ("abc", [0, 0, 0]),
        ("aaaa", [0, 1, 2, 3]),
        # Degenerate lengths.
        ("", []),
        ("a", [0]),
    ]
    for needle, expected_table in lps_cases:
        got_table = _build_lps(needle)
        assert got_table == expected_table, (needle, got_table, expected_table)


def stress_test() -> None:
    """Cross-check ``strstr_kmp`` against ``strstr_brute`` on random inputs.

    Runs 500 trials from a fixed seed (42) over the alphabet "ab", with
    haystack lengths 0..12 and needle lengths 0..5. The tiny alphabet makes
    repeated and overlapping matches likely.

    Raises:
        AssertionError: on the first disagreement, carrying
            ``(haystack, needle, kmp_result, brute_result)``.
    """
    rng = random.Random(42)
    letters = "ab"

    def random_word(size: int) -> str:
        # Draw ``size`` characters from ``letters`` using the shared RNG.
        return "".join(rng.choice(letters) for _ in range(size))

    for _ in range(500):
        # Both lengths are drawn before either string, keeping the RNG
        # consumption order fixed for reproducibility.
        hay_len = rng.randint(0, 12)
        needle_len = rng.randint(0, 5)
        haystack = random_word(hay_len)
        needle = random_word(needle_len)
        kmp_result = strstr_kmp(haystack, needle)
        brute_result = strstr_brute(haystack, needle)
        assert kmp_result == brute_result, (
            haystack,
            needle,
            kmp_result,
            brute_result,
        )


if __name__ == "__main__":
    # Run the deterministic checks first, then the randomized cross-check;
    # the success line is reached only if no assertion fired.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
