"""
p104 — Design Autocomplete System (LC 642)

API:
  input(c):
    - if c == '#': commit self.buf to trie (count += 1; insert if new); reset; return [].
    - else: append c to self.buf; descend self.cur if possible; return the top-3
            sentences that start with self.buf, ranked by (-frequency, sentence).
            Once off-trie, return [] and stay off-trie until the next '#'.

INVARIANT A: self.cur is either the trie node for self.buf, or None (off-trie).
INVARIANT B: A node's `count` is positive iff a sentence ends exactly at that node
             (provided every initial `times` entry is positive).
INVARIANT C: '#' is the only character that resets state.

Two strategies for top-3:
  A. DFS the subtree rooted at self.cur, collect all (count, sentence), sort, take 3.
     Simple; O(S + m log m) per query for S subtree nodes and m matching sentences.
     Used here.
  B. Maintain a top-3 list at each node on insert. O(L * K log K) insert; O(1) query.
     Production-grade; not implemented here for clarity.
"""
from __future__ import annotations
import random
from typing import Dict, List, Optional, Tuple


class _Node:
    """A single trie vertex.

    Holds its outgoing edges plus the bookkeeping for a sentence that may end
    here.  Pass-through vertices keep the defaults: count 0 and sentence "".
    """

    __slots__ = ("children", "count", "sentence")

    # Outgoing edges, keyed by the character that labels each edge.
    children: Dict[str, "_Node"]
    # Accumulated frequency of the sentence that ends at this vertex.
    count: int
    # The full sentence text for terminal vertices; "" elsewhere.
    sentence: str

    def __init__(self) -> None:
        self.children, self.count, self.sentence = {}, 0, ""


class AutocompleteSystem:
    """Trie-backed autocomplete (LC 642).

    State carried between calls:
      root -- trie of every committed sentence; terminal nodes hold the
              sentence text and its accumulated frequency.
      buf  -- characters typed since the last '#'.
      cur  -- trie node that spells `buf`, or None once `buf` has left the trie.
    """

    def __init__(self, sentences: List[str], times: List[int]):
        self.root = _Node()
        for text, freq in zip(sentences, times):
            self._insert(text, freq)
        self.buf: str = ""
        self.cur: Optional[_Node] = self.root

    def _insert(self, sentence: str, freq: int) -> None:
        """Add `freq` to the count of `sentence`, creating any missing nodes."""
        walker = self.root
        for ch in sentence:
            if ch not in walker.children:
                walker.children[ch] = _Node()
            walker = walker.children[ch]
        walker.sentence = sentence
        walker.count += freq

    def _collect(self, node: _Node) -> List[Tuple[int, str]]:
        """Return (count, sentence) for every terminal node under `node`.

        Uses an explicit stack instead of recursion, so deep tries do not run
        into the interpreter's recursion limit.  Nodes whose count is not
        positive (pass-through nodes) are skipped.
        """
        found: List[Tuple[int, str]] = []
        pending: List[_Node] = [node]
        while pending:
            current = pending.pop()
            pending.extend(current.children.values())
            if current.count > 0:
                found.append((current.count, current.sentence))
        return found

    def input(self, c: str) -> List[str]:
        """Feed one character and return up to three ranked suggestions.

        '#' commits the buffer as a sentence (frequency +1), clears the
        buffer, rewinds the cursor to the root and returns [].  Any other
        character extends the buffer.  Once the buffer has left the trie,
        every call returns [] until the next '#'.
        Ranking: higher frequency first, then ascending sentence text.
        """
        if c != "#":
            self.buf += c
            # A cursor that is already off-trie stays off-trie.
            self.cur = None if self.cur is None else self.cur.children.get(c)
            if self.cur is None:
                return []
            ranked = sorted(
                self._collect(self.cur),
                key=lambda pair: (-pair[0], pair[1]),
            )
            return [sentence for _, sentence in ranked[:3]]
        self._insert(self.buf, 1)
        self.buf, self.cur = "", self.root
        return []


# -----------------------------------------------------------------------------
# Brute oracle: keep a dict of sentence -> count; linear scan per input.
# -----------------------------------------------------------------------------
class AutocompleteSystemBrute:
    """Reference oracle for AutocompleteSystem.

    Keeps a flat sentence -> frequency map and rescans all of it on every
    keystroke, so its answers are easy to trust but cost O(#sentences).
    """

    def __init__(self, sentences: List[str], times: List[int]):
        self.counts: Dict[str, int] = {}
        for pair in zip(sentences, times):
            self._bump(*pair)
        self.buf: str = ""

    def _bump(self, sentence: str, amount: int) -> None:
        """Add `amount` to the stored frequency of `sentence` (default 0)."""
        self.counts[sentence] = self.counts.get(sentence, 0) + amount

    def input(self, c: str) -> List[str]:
        """Same contract as AutocompleteSystem.input, computed by linear scan."""
        if c != "#":
            self.buf += c
            matches = sorted(
                ((cnt, s) for s, cnt in self.counts.items() if s.startswith(self.buf)),
                key=lambda item: (-item[0], item[1]),
            )
            return [s for _, s in matches[:3]]
        self._bump(self.buf, 1)
        self.buf = ""
        return []


# -----------------------------------------------------------------------------
# Edge cases
# -----------------------------------------------------------------------------
def edge_cases() -> None:
    """Run deterministic regression checks against AutocompleteSystem.

    Covers: the LeetCode 642 example, including an off-trie excursion that is
    later committed; prefixes that select a single sentence; genuine
    frequency ties broken by ascending sentence order (space sorts before
    letters); frequency outranking lexical order after commits; and an empty
    initial history.  Raises AssertionError on the first mismatch.
    """
    # LC example
    a = AutocompleteSystem(
        ["i love you", "island", "ironman", "i love leetcode"],
        [5, 3, 2, 2],
    )
    assert a.input("i") == ["i love you", "island", "i love leetcode"]
    assert a.input(" ") == ["i love you", "i love leetcode"]
    assert a.input("a") == []          # off-trie
    assert a.input("#") == []          # commit "i a"
    # Now "i a" exists with count 1; "i love you" still tops "i"
    assert a.input("i") == ["i love you", "island", "i love leetcode"]
    assert a.input(" ") == ["i love you", "i love leetcode", "i a"]
    assert a.input("a") == ["i a"]     # on-trie this time: "i a" was committed
    assert a.input("#") == []

    # A prefix that selects exactly one of several equal-count sentences.
    b = AutocompleteSystem(["bbb", "aaa", "ccc"], [1, 1, 1])
    assert b.input("a") == ["aaa"]
    assert b.input("#") == []
    b2 = AutocompleteSystem(["bb", "aa"], [1, 1])
    assert b2.input("a") == ["aa"]
    b2.input("#")
    b3 = AutocompleteSystem(["x", "y", "z"], [3, 3, 3])
    assert b3.input("y") == ["y"]
    b3.input("#")

    # Real tie-break: four sentences share prefix "a" and count 2, so only
    # ascending sentence order decides which three are returned.
    t = AutocompleteSystem(["ad", "ab", "ac", "aa"], [2, 2, 2, 2])
    assert t.input("a") == ["aa", "ab", "ac"]
    assert t.input("#") == []                  # commit "a" with count 1
    assert t.input("a") == ["aa", "ab", "ac"]  # "a" (count 1) ranks last
    assert t.input("d") == ["ad"]
    assert t.input("#") == []                  # "ad" now has count 3
    assert t.input("a") == ["ad", "aa", "ab"]  # frequency beats lexical order
    assert t.input("#") == []

    # Space (0x20) sorts before letters when counts tie.
    sp = AutocompleteSystem(["ia", "i a"], [1, 1])
    assert sp.input("i") == ["i a", "ia"]

    # Empty history
    e = AutocompleteSystem([], [])
    assert e.input("h") == []
    assert e.input("i") == []
    assert e.input("#") == []
    assert e.input("h") == ["hi"]
    assert e.input("i") == ["hi"]

    print("edge_cases OK")


# -----------------------------------------------------------------------------
# Stress
# -----------------------------------------------------------------------------
def _rand_sentence(rng: random.Random) -> str:
    """Draw a sentence of 1 to 6 characters over the alphabet 'abc '.

    Consumes `rng` in a fixed order: one randint for the length, then one
    choice per character.  Callers depend on that order for reproducibility.
    """
    length = rng.randint(1, 6)
    chars = [rng.choice("abc ") for _ in range(length)]
    return "".join(chars)


def stress_test(trials: int = 60, ops: int = 80, seed: int = 42) -> None:
    """Differential test of AutocompleteSystem against AutocompleteSystemBrute.

    Each trial builds both systems from the same random history, streams
    identical random characters into them, and asserts identical outputs.

    Args:
        trials: number of independent system pairs to build and compare.
        ops: characters streamed into each pair.
        seed: seed for the private RNG.  The default (42) reproduces the
            historical run; other values explore different streams.

    Raises:
        AssertionError: on the first divergence.  The message carries the
            seed, trial index, buffer before the keystroke, the keystroke,
            and both results, so the failure can be reproduced.
    """
    rng = random.Random(seed)
    for trial in range(trials):
        n_init = rng.randint(0, 6)
        sents = [_rand_sentence(rng) for _ in range(n_init)]
        times = [rng.randint(1, 5) for _ in range(n_init)]
        # Defensive copies, so the two implementations never share input lists.
        a = AutocompleteSystem(list(sents), list(times))
        b = AutocompleteSystemBrute(list(sents), list(times))
        for _ in range(ops):
            # Mostly characters from the sentence alphabet; ~15% are commits.
            ch = "#" if rng.random() < 0.15 else rng.choice("abc ")
            before = b.buf
            ra = a.input(ch)
            rb = b.input(ch)
            assert ra == rb, (seed, trial, before, ch, ra, rb)
    print(f"stress_test OK ({trials} trials)")


if __name__ == "__main__":
    # Deterministic checks first, then the randomized differential test.
    for check in (edge_cases, stress_test):
        check()
