#!/usr/bin/env python3
"""
Phoenix Zero — Node Comparison Tool

Measures the latency of your own RPC node and compares it with the Phoenix Zero feed.
Shows how early Phoenix Zero detects L2 network spikes.

Usage:
    python compare.py
    python compare.py --rpc https://arb1.arbitrum.io/rpc --chain arbitrum
    python compare.py --rpc https://mainnet.base.org --chain base --time 300

Dashboard: https://phoenix-zero.vercel.app
"""

import argparse, time, json, urllib.request, ssl, statistics
from datetime import datetime, timezone

# Public (delayed) Phoenix Zero JSON feed polled by this tool.
PHOENIX_FEED = "https://rtt.phoenix-ai.work/api/public-feed"
# Any p99 latency above this many milliseconds is reported as a spike.
SPIKE_MS = 150
# Number of JSON-RPC requests sent per measurement round.
PING_N = 8

# Endpoint used for each chain when --rpc is not supplied.
# Key order matters: it becomes the order of the --chain choices.
DEFAULT_RPC = dict(
    arbitrum="https://arb1.arbitrum.io/rpc",
    optimism="https://mainnet.optimism.io",
    base="https://mainnet.base.org",
    zksync="https://mainnet.era.zksync.io",
)
# Field name, inside each feed bucket, that holds the chain's p99 latency.
CHAIN_KEY = dict(
    arbitrum="arb_p99",
    optimism="op_p99",
    base="base_p99",
    zksync="zk_p99",
)


def _ctx():
    return ssl.create_default_context()


def fetch_feed() -> dict:
    """GET the Phoenix Zero public feed and return the decoded JSON payload.

    Uses a 10 s timeout; network, HTTP and JSON errors propagate to the caller.
    """
    request = urllib.request.Request(
        PHOENIX_FEED, headers={"User-Agent": "phoenix-compare/1.0"}
    )
    with urllib.request.urlopen(request, context=_ctx(), timeout=10) as resp:
        payload = resp.read()
    return json.loads(payload)


def measure_rtt(rpc_url: str) -> float | None:
    """Ping a JSON-RPC endpoint PING_N times and return its worst-case latency.

    Each ping is an ``eth_blockNumber`` POST. The timed span covers the whole
    request: connect, TLS handshake, request and response body. Pings are
    spaced 0.25 s apart. A ping that raises (bad URL, timeout, HTTP error, ...)
    is dropped.

    Args:
        rpc_url: HTTP(S) URL of the JSON-RPC endpoint.

    Returns:
        The 99th-percentile round-trip time in ms, rounded to 0.1 ms. For up
        to 100 samples this index is simply the slowest successful ping.
        Returns None if fewer than 3 pings succeeded.
    """
    body = json.dumps({
        "jsonrpc": "2.0", "id": 1,
        "method": "eth_blockNumber", "params": [],
    }).encode()
    headers = {
        "Content-Type": "application/json",
        "User-Agent": "phoenix-compare/1.0",
    }
    # Build the TLS context once, outside the timed region. Creating it loads
    # the CA certificates, and doing that per request inflated every sample.
    ctx = _ctx()
    samples = []
    for i in range(PING_N):
        if i:
            # Space the pings out; no pointless sleep after the last one.
            time.sleep(0.25)
        try:
            req = urllib.request.Request(rpc_url, data=body, headers=headers)
            t0 = time.perf_counter()
            with urllib.request.urlopen(req, context=ctx, timeout=5) as r:
                r.read()
            samples.append((time.perf_counter() - t0) * 1000)
        except Exception:
            # Best effort: a failed ping is skipped; too many failures -> None.
            continue
    if len(samples) < 3:
        return None
    s = sorted(samples)
    return round(s[min(int(len(s) * 0.99), len(s) - 1)], 1)


def utc(ts: float) -> str:
    """Render a POSIX timestamp (seconds) as ``HH:MM UTC``."""
    moment = datetime.fromtimestamp(ts, timezone.utc)
    return f"{moment:%H:%M} UTC"


def now_utc() -> str:
    """Return the current time of day in UTC as ``HH:MM:SS UTC``."""
    current = datetime.now(tz=timezone.utc)
    return f"{current:%H:%M:%S} UTC"


def _split_feed(data: list, key: str) -> tuple[list, list]:
    """Partition feed buckets for one chain into (present, spikes).

    Args:
        data: the feed's ``data`` list of bucket dicts. The schema is assumed
            from usage here (a ``ts`` timestamp plus per-chain p99 fields);
            verify against the feed.
        key: the chain's field name from CHAIN_KEY, e.g. ``"arb_p99"``.

    Returns:
        ``present``: buckets whose ``key`` value is non-zero.
        ``spikes``: the subset of ``present`` whose value exceeds SPIKE_MS.
        Both lists keep the feed's original order. Buckets whose value is
        missing, None or 0 are excluded from both.
    """
    present = [b for b in data if b.get(key)]
    spikes = [b for b in present if b[key] > SPIKE_MS]
    return present, spikes


def _print_feed_round(key: str) -> None:
    """Fetch the feed and finish the current round's one-line feed summary."""
    present, spikes = _split_feed(fetch_feed().get("data", []), key)
    if not present:
        print("warming up (<6 min since deploy)")
    elif spikes:
        # NOTE(review): "last" assumes buckets are ordered oldest -> newest.
        latest = spikes[-1]
        print(
            f"{len(spikes)} spike(s) detected | "
            f"last spike: {utc(latest['ts'])} → {latest[key]:.0f} ms"
        )
    else:
        avg = statistics.mean(b[key] for b in present)
        print(f"calm — avg p99={avg:.1f} ms across {len(present)} min")


def _print_feed_summary(chain: str, key: str) -> None:
    """Fetch the feed once more and print its stats and spike timeline for `chain`."""
    present, spikes = _split_feed(fetch_feed().get("data", []), key)

    if present:
        vals = [b[key] for b in present]
        # Take the timestamp from the peak bucket itself. A position in `vals`
        # does not line up with `data`, because empty buckets are filtered out.
        peak = max(present, key=lambda b: b[key])
        print(f"  Phoenix Zero — last ~60 min ({chain}):")
        print(f"    avg p99 : {statistics.mean(vals):.1f} ms")
        print(f"    max p99 : {peak[key]:.0f} ms  @ {utc(peak['ts'])}")
        print(f"    spikes  : {len(spikes)} events > {SPIKE_MS} ms")

    if spikes:
        print()
        print("  Spike timeline (Phoenix Zero, 5-min delayed):")
        for s in spikes[-6:]:
            print(f"    {utc(s['ts'])}  →  {s[key]:.0f} ms")
        print()
        print("  Real-time tier subscribers saw these spikes")
        print("  5 minutes earlier — before they hit your node.")


def _print_report(chain: str, key: str, my_history: list) -> None:
    """Print the end-of-session report: feed stats, your node's stats, footer."""
    print("  ─────────────────────────────────────────────────────")
    print("  REPORT")
    print("  ─────────────────────────────────────────────────────")
    try:
        _print_feed_summary(chain, key)
    except Exception as e:
        # Best effort: keep the rest of the report, but say why the feed part is missing.
        print(f"  Phoenix Zero feed unavailable: {e}")

    if my_history:
        my_avg = statistics.mean(h["p99"] for h in my_history)
        print()
        print("  Your node — this session:")
        print(f"    rounds  : {len(my_history)}")
        print(f"    avg p99 : {my_avg:.1f} ms")

    print()
    print("  ─────────────────────────────────────────────────────")
    print("  Want real-time access? Open an issue on GitHub:")
    print("  https://github.com/kant19801201behax5/phoenix-zero-public")
    print()
    print("  Free demo key available. Searcher tier: $199/mo")
    print("  wss://rtt.phoenix-ai.work/ws  (auth on connect)")
    print("  ─────────────────────────────────────────────────────")
    print()


def run(rpc_url: str, chain: str, duration_s: int):
    """Compare `rpc_url` with the Phoenix Zero feed for about `duration_s` seconds.

    Rounds are scheduled every 60 s from the start. Each round pings the node
    with measure_rtt and prints a one-line summary of the feed for `chain`.
    The loop stops when the next round would start at or after `duration_s`,
    or on Ctrl-C. Either way, a final report is then printed.

    Args:
        rpc_url: JSON-RPC endpoint to measure.
        chain: a key of CHAIN_KEY; any other value raises KeyError.
        duration_s: requested session length in seconds. At least one round
            always runs.
    """
    key = CHAIN_KEY[chain]

    print()
    print("  ┌─────────────────────────────────────────────────────┐")
    print("  │         Phoenix Zero — Node Comparison Tool         │")
    print("  └─────────────────────────────────────────────────────┘")
    print(f"  Chain  : {chain.upper()}")
    print(f"  Your RPC: {rpc_url}")
    print(f"  Feed   : {PHOENIX_FEED}  (5-min delayed, public tier)")
    print(f"  Running for {duration_s}s, sampling every 60s")
    print()

    my_history = []
    start = time.time()
    round_n = 0

    try:
        while True:
            round_n += 1
            t = now_utc()

            # ── measure your node ──────────────────────────────────
            print(f"  [{t}] Measuring your node ({PING_N} pings)… ", end="", flush=True)
            my_p99 = measure_rtt(rpc_url)
            if my_p99 is None:
                print("FAILED — check your RPC URL")
            else:
                flag = " ⚠ SPIKE" if my_p99 > SPIKE_MS else ""
                print(f"p99 = {my_p99:.1f} ms{flag}")
                my_history.append({"ts": time.time(), "p99": my_p99})

            # ── fetch Phoenix Zero feed ────────────────────────────
            print(f"  [{t}] Phoenix Zero feed… ", end="", flush=True)
            try:
                _print_feed_round(key)
            except Exception as e:
                print(f"error: {e}")

            print()

            # Rounds sit on a fixed 60 s grid measured from `start`. That way
            # the time spent pinging does not push later rounds back.
            next_offset = round_n * 60
            if next_offset >= duration_s:
                break
            time.sleep(max(0.0, start + next_offset - time.time()))
    except KeyboardInterrupt:
        # Ctrl-C ends sampling early but still reports what was collected.
        print("\n  Interrupted — reporting on the rounds collected so far.")

    _print_report(chain, key, my_history)


if __name__ == "__main__":
    # CLI entry point: parse options, pick a default RPC for the chain, run.
    parser = argparse.ArgumentParser(
        description="Compare your RPC node latency to Phoenix Zero feed"
    )
    parser.add_argument(
        "--chain",
        default="arbitrum",
        choices=list(CHAIN_KEY),
        help="L2 chain (default: arbitrum)",
    )
    parser.add_argument(
        "--rpc",
        default=None,
        help="Your RPC endpoint URL (default: public endpoint for chain)",
    )
    parser.add_argument(
        "--time",
        default=180,
        type=int,
        help="Duration in seconds (default: 180)",
    )
    opts = parser.parse_args()

    run(opts.rpc or DEFAULT_RPC[opts.chain], opts.chain, opts.time)