"""
p105 — All O(1) Data Structure (LC 432)

A doubly-linked list of (count, set_of_keys) buckets, ordered ASCENDING by count.
  head <-> bucket(c_min) <-> ... <-> bucket(c_max) <-> tail

Plus:
  key_to_count: dict[str, int]
  count_to_bucket: dict[int, Bucket]  — O(1) bucket lookup by count

INVARIANT 1: Every non-sentinel bucket has count > 0 and at least one key.
INVARIANT 2: Buckets in the DLL are strictly increasing in count from head to tail.
INVARIANT 3: For every key k with count c > 0, k ∈ count_to_bucket[c].keys.

All four operations (inc, dec, getMaxKey, getMinKey) run in O(1) average time (dict/set hashing).
"""
from __future__ import annotations
import random
from typing import Dict, Optional, Set


class _Bucket:
    """A list node holding every key whose current count equals ``count``."""

    __slots__ = ("count", "keys", "prev", "next")

    def __init__(self, count: int) -> None:
        self.count = count
        self.keys: Set[str] = set()
        self.prev: Optional["_Bucket"] = None
        self.next: Optional["_Bucket"] = None


class AllOne:
    """Key counter with O(1) inc/dec and O(1) "any key with max/min count".

    Keys are grouped into buckets by count. Buckets form a doubly-linked list
    between two sentinels, strictly ascending in count from ``head`` to
    ``tail``. As a result, ``head.next`` is always the min-count bucket and
    ``tail.prev`` is always the max-count bucket.
    """

    def __init__(self) -> None:
        # Sentinel head/tail. Sentinels carry no real count and are never
        # registered in count_to_bucket.
        self.head = _Bucket(0)
        self.tail = _Bucket(0)
        self.head.next = self.tail
        self.tail.prev = self.head
        self.key_to_count: Dict[str, int] = {}
        self.count_to_bucket: Dict[int, _Bucket] = {}

    # ---- DLL helpers --------------------------------------------------------
    def _insert_after(self, anchor: _Bucket, new_bucket: _Bucket) -> None:
        """Splice ``new_bucket`` into the list immediately after ``anchor``."""
        nxt = anchor.next
        new_bucket.prev = anchor
        new_bucket.next = nxt
        anchor.next = new_bucket
        nxt.prev = new_bucket

    def _unlink(self, bucket: _Bucket) -> None:
        """Remove ``bucket`` from the list and clear its links."""
        bucket.prev.next = bucket.next
        bucket.next.prev = bucket.prev
        bucket.prev = None
        bucket.next = None

    def _drop_if_empty(self, bucket: _Bucket) -> None:
        """Unregister and unlink ``bucket`` once it holds no keys (INVARIANT 1)."""
        if not bucket.keys:
            del self.count_to_bucket[bucket.count]
            self._unlink(bucket)

    def _bucket_for(self, count: int, anchor: _Bucket) -> _Bucket:
        """Return the bucket for ``count``, creating it right after ``anchor``.

        Callers must choose ``anchor`` so that ``count`` sorts between
        ``anchor`` and ``anchor.next``. Otherwise INVARIANT 2 (strictly
        ascending counts) would break.
        """
        bucket = self.count_to_bucket.get(count)
        if bucket is None:
            bucket = _Bucket(count)
            self._insert_after(anchor, bucket)
            self.count_to_bucket[count] = bucket
        return bucket

    # ---- Public API ---------------------------------------------------------
    def inc(self, key: str) -> None:
        """Increment ``key``'s count by one; a new key starts at count 1."""
        old_c = self.key_to_count.get(key)
        if old_c is None:
            # Count 1 is the smallest possible count, so its bucket sits at the head.
            self._bucket_for(1, self.head).keys.add(key)
            self.key_to_count[key] = 1
            return
        old_bucket = self.count_to_bucket[old_c]
        # Counts increase toward tail, so c+1 goes immediately AFTER old_bucket.
        new_bucket = self._bucket_for(old_c + 1, old_bucket)
        new_bucket.keys.add(key)
        old_bucket.keys.discard(key)
        self.key_to_count[key] = old_c + 1
        self._drop_if_empty(old_bucket)

    def dec(self, key: str) -> None:
        """Decrement ``key``'s count by one, removing the key when it reaches 0.

        Per LC 432 the key is guaranteed to exist. A missing key raises
        ``KeyError``.
        """
        old_c = self.key_to_count[key]
        old_bucket = self.count_to_bucket[old_c]
        if old_c == 1:
            old_bucket.keys.discard(key)
            del self.key_to_count[key]
        else:
            # c-1 goes immediately BEFORE old_bucket.
            new_bucket = self._bucket_for(old_c - 1, old_bucket.prev)
            new_bucket.keys.add(key)
            old_bucket.keys.discard(key)
            self.key_to_count[key] = old_c - 1
        self._drop_if_empty(old_bucket)

    def getMaxKey(self) -> str:
        """Return any key with the maximal count, or "" when empty."""
        if self.tail.prev is self.head:
            return ""
        return next(iter(self.tail.prev.keys))

    def getMinKey(self) -> str:
        """Return any key with the minimal count, or "" when empty."""
        if self.head.next is self.tail:
            return ""
        return next(iter(self.head.next.keys))


# -----------------------------------------------------------------------------
# Brute oracle: dict + scan
# -----------------------------------------------------------------------------
class AllOneBrute:
    """Reference oracle: a plain dict of counts, scanned linearly for min/max.

    ``max``/``min`` return the first extremal key in dict insertion order,
    so ties resolve to the earliest-inserted surviving key.
    """

    def __init__(self) -> None:
        self.counts: Dict[str, int] = {}

    def inc(self, key: str) -> None:
        current = self.counts.get(key, 0)
        self.counts[key] = current + 1

    def dec(self, key: str) -> None:
        # Raises KeyError for an absent key, exactly like ``counts[key] -= 1``.
        remaining = self.counts[key] - 1
        if remaining:
            self.counts[key] = remaining
        else:
            del self.counts[key]

    def getMaxKey(self) -> str:
        return max(self.counts, key=self.counts.get) if self.counts else ""

    def getMinKey(self) -> str:
        return min(self.counts, key=self.counts.get) if self.counts else ""


# -----------------------------------------------------------------------------
# Edge cases
# -----------------------------------------------------------------------------
def edge_cases() -> None:
    """Deterministic AllOne scenarios: empty state, ties, removal, mid-list buckets.

    LC 432 lets getMaxKey/getMinKey return *any* key with the extreme count,
    so tie cases assert set membership rather than one specific key.
    """
    ds = AllOne()
    assert ds.getMaxKey() == ""
    assert ds.getMinKey() == ""

    for _ in range(2):
        ds.inc("hello")                      # hello:2
    assert ds.getMaxKey() == "hello"
    assert ds.getMinKey() == "hello"

    ds.inc("world")                          # world:1
    assert ds.getMaxKey() == "hello"
    assert ds.getMinKey() == "world"

    ds.inc("world")                          # both at 2, same bucket
    tied = {"hello", "world"}
    assert ds.getMaxKey() in tied
    assert ds.getMinKey() in tied

    ds.dec("hello")                          # hello:1, world:2
    assert ds.getMaxKey() == "world"
    assert ds.getMinKey() == "hello"

    ds.dec("hello")                          # hello removed
    assert ds.getMaxKey() == "world"
    assert ds.getMinKey() == "world"

    for _ in range(2):
        ds.dec("world")                      # world removed -> empty
    assert ds.getMaxKey() == ""
    assert ds.getMinKey() == ""

    # A new bucket must be spliced in BETWEEN two existing ones.
    mid = AllOne()
    for _ in range(3):
        mid.inc("a")                         # a:3
    mid.inc("c")                             # c:1
    for _ in range(2):
        mid.inc("b")                         # b:2 -> bucket(2) between c(1) and a(3)
    assert mid.getMinKey() == "c"
    assert mid.getMaxKey() == "a"
    for _ in range(2):
        mid.dec("a")                         # a:1 joins c in bucket(1)
    assert mid.getMaxKey() == "b"
    assert mid.getMinKey() in ("a", "c")

    print("edge_cases OK")


# -----------------------------------------------------------------------------
# Stress
# -----------------------------------------------------------------------------
def stress_test(trials: int = 100, ops: int = 300) -> None:
    """Randomized differential test of AllOne against the AllOneBrute oracle.

    ``dec`` is only issued for keys currently present, honouring LC 432's
    precondition. Ties are legal, so a getMax/getMin answer is accepted when
    the returned key has the same count as the oracle's answer.
    """
    rng = random.Random(42)
    op_names = ["inc", "dec", "max", "min"]
    op_weights = [5, 3, 2, 2]

    for trial in range(trials):
        fast, oracle = AllOne(), AllOneBrute()
        live: Set[str] = set()  # keys whose count is currently > 0
        for _ in range(ops):
            op = rng.choices(op_names, weights=op_weights)[0]
            if op == "inc":
                key = rng.choice("abcdef")
                fast.inc(key)
                oracle.inc(key)
                live.add(key)
                continue
            if op == "dec":
                if not live:
                    continue
                key = rng.choice(sorted(live))
                fast.dec(key)
                oracle.dec(key)
                if oracle.counts.get(key, 0) == 0:
                    live.discard(key)
                continue

            if op == "max":
                got, want = fast.getMaxKey(), oracle.getMaxKey()
            else:
                got, want = fast.getMinKey(), oracle.getMinKey()
            # Both sides must agree on emptiness. Otherwise the returned key
            # must carry the oracle's extreme count.
            if want == "":
                assert got == ""
            else:
                target = oracle.counts[want]
                assert got != ""
                assert oracle.counts.get(got) == target, (
                    trial, got, want, oracle.counts
                )
    print(f"stress_test OK ({trials} trials)")


if __name__ == "__main__":
    # Deterministic scenarios first, then the randomized oracle comparison.
    for suite in (edge_cases, stress_test):
        suite()
