"""
p94 — Repeated DNA Sequences (LC 187)

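INVARIANT (bit-encoded): each base maps to 2 bits (A=0, C=1, G=2, T=3), so a
  rolling 20-bit integer h holds the last 10 bases:
  h_new = ((h << 2) | code[c]) & mask, with mask = (1 << 20) - 1. Once at
  least 10 chars have been fed, h uniquely identifies the most recent 10-mer.

INVARIANT (polynomial): with k = 10 and s[x] standing for ord(s[x]), the
  window ending at index i has h = sum_{j=0..k-1} s[i-k+1+j] * BASE^(k-1-j) mod MOD.
  Slide: h = (h * BASE - s[i-k] * BASE^k + s[i]) mod MOD. Windows with equal
  hashes are compared directly, so collisions cannot cause false positives.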
"""
from __future__ import annotations

import random
from typing import List


# Window length: LC 187 asks for 10-letter sequences (10-mers) that repeat.
K = 10


def find_repeated_set(s: str) -> List[str]:
    """Return the length-K substrings of *s* that occur at least twice, sorted.

    Naive approach: every window is sliced into its own string and stored in
    a set, so both time and memory are O(n*K). Inputs shorter than K yield
    no windows and therefore an empty list.
    """
    once: set[str] = set()
    dupes: set[str] = set()
    windows = (s[start : start + K] for start in range(len(s) - K + 1))
    for window in windows:
        if window not in once:
            once.add(window)
        else:
            dupes.add(window)
    return sorted(dupes)


# 2-bit code per nucleotide, used by find_repeated_bit. Any other character
# (lowercase, 'N', ...) is not a key here and raises KeyError in that lookup.
_CODE = {"A": 0, "C": 1, "G": 2, "T": 3}


def find_repeated_bit(s: str) -> List[str]:
    """Find repeated K-mers by packing each window into a 2K-bit integer.

    Every base contributes two bits via ``_CODE``, so a full window of K
    bases maps to a distinct integer and no string verification is needed.
    The substring is sliced only the first time its window repeats. O(n).
    Characters outside ``_CODE`` raise KeyError.
    """
    if len(s) < K:
        return []
    keep = (1 << (2 * K)) - 1  # low 2K bits = the last K bases
    packed = 0
    first_seen: set[int] = set()
    reported: set[int] = set()
    result: set[str] = set()
    for end, base in enumerate(s):
        packed = ((packed << 2) | _CODE[base]) & keep
        if end < K - 1:
            continue  # window not full yet
        if packed not in first_seen:
            first_seen.add(packed)
        elif packed not in reported:
            reported.add(packed)
            result.add(s[end - K + 1 : end + 1])
    return sorted(result)


def find_repeated_polynomial(s: str) -> List[str]:
    """Find repeated K-mers with a polynomial rolling hash plus verification.

    For the window s[start : start + K] the hash is
    sum_j ord(s[start + j]) * BASE^(K-1-j) mod MOD. Windows whose hashes match
    are compared character-by-character, so a collision can never produce a
    false positive.

    Each hash bucket keeps one start index per *distinct* substring. Repeated
    occurrences therefore do not grow the bucket, and a colliding substring
    never has to be scanned past more than once.

    Time: O(n*K) for the per-window slices, i.e. O(n) expected for fixed K.
    Memory: O(number of distinct windows).
    """
    if len(s) < K:
        return []
    BASE = 257
    MOD = (1 << 61) - 1
    base_k = pow(BASE, K, MOD)  # weight of the char leaving the window
    h = 0
    # hash -> start indices of pairwise-distinct windows having that hash
    buckets: dict[int, list[int]] = {}
    repeated: set[str] = set()
    for i, c in enumerate(s):
        h = (h * BASE + ord(c)) % MOD
        if i >= K:
            # After the push above, s[i - K] carries weight BASE^K; remove it.
            h = (h - ord(s[i - K]) * base_k) % MOD
        if i < K - 1:
            continue  # window not full yet
        start = i - K + 1
        cur = s[start : i + 1]
        bucket = buckets.setdefault(h, [])
        if any(s[prev : prev + K] == cur for prev in bucket):
            repeated.add(cur)
        else:
            # First occurrence of this substring (or a genuine hash collision).
            bucket.append(start)
    return sorted(repeated)


def find_repeated_brute(s: str) -> List[str]:
    """Reference oracle: tally every length-K window and keep those seen twice+.

    Deliberately simple so the faster implementations can be checked against it.
    """
    tally: dict[str, int] = {}
    for offset in range(len(s) - K + 1):
        window = s[offset : offset + K]
        tally.setdefault(window, 0)
        tally[window] += 1
    return sorted(window for window, hits in tally.items() if hits >= 2)


def edge_cases() -> None:
    """Check every implementation, including the oracle, on hand-picked inputs."""
    impls = (
        find_repeated_set,
        find_repeated_bit,
        find_repeated_polynomial,
        find_repeated_brute,
    )
    cases = [
        # LeetCode 187 example: two distinct repeats.
        ("AAAAACCCCCAAAAACCCCCCAAAAAGGGTTT", ["AAAAACCCCC", "CCCCCAAAAA"]),
        # Empty / too short: no window of length K fits.
        ("", []),
        ("AAAA", []),
        ("A" * (K - 1), []),
        # Exactly K: a single window cannot repeat.
        ("A" * K, []),
        # K + 1 identical letters: two overlapping, equal windows.
        ("A" * (K + 1), ["A" * K]),
        # A window is reported once, however often it repeats.
        ("A" * (3 * K), ["A" * K]),
        # Period-4 string using all four bases: three overlapping repeats.
        ("ACGT" * 4, ["ACGTACGTAC", "CGTACGTACG", "GTACGTACGT"]),
    ]
    for s, expected in cases:
        for impl in impls:
            got = impl(s)
            assert got == expected, (impl.__name__, s, got, expected)


def stress_test() -> None:
    """Cross-check the fast implementations against the brute-force oracle.

    Short, uniformly random ACGT strings almost never contain a repeated
    K-mer, because there are 4**K possible windows. Such input would leave
    the repeat-handling code paths untested. Each trial therefore draws from
    a random prefix of the alphabet ("A", "AC", "ACG" or "ACGT"). Low-entropy
    strings repeat often, and full-alphabet strings still cover the
    no-repeat path.
    """
    rng = random.Random(42)
    alphabet = "ACGT"
    impls = (find_repeated_set, find_repeated_bit, find_repeated_polynomial)
    trials_with_repeats = 0
    for _ in range(300):
        letters = alphabet[: rng.randint(1, len(alphabet))]
        n = rng.randint(0, 60)
        s = "".join(rng.choice(letters) for _ in range(n))
        expected = find_repeated_brute(s)
        if expected:
            trials_with_repeats += 1
        for impl in impls:
            got = impl(s)
            assert got == expected, (impl.__name__, s, got, expected)
    # Guard against the generator silently drifting back to repeat-free input.
    assert trials_with_repeats > 0, "stress test never produced a repeat"


if __name__ == "__main__":
    # Run the hand-picked checks first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
