"""
palace-discover: Derive spatial structure from a vector store.

Reads a ChromaDB collection, runs hierarchical clustering (HDBSCAN),
projects to 3D (UMAP), labels clusters via Claude, and emits a
topology JSON that a Three.js viewer can render as an explorable building.

Usage:
    python discover.py --palace ~/.mempalace/palace --collection mempalace_drawers
    python discover.py --palace ./my_chroma_db --out topology.json
"""

import argparse
import json
import sys
import os
from dataclasses import dataclass, field, asdict
from pathlib import Path

import chromadb
import hdbscan
import numpy as np
import umap
from sklearn.metrics.pairwise import cosine_similarity


# ---------------------------------------------------------------------------
# Data model — the topology contract between discovery and viewer
# ---------------------------------------------------------------------------

@dataclass
class Document:
    id: str
    content_preview: str  # first ~200 chars
    metadata: dict
    position_3d: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.0])
    room_id: str | None = None

@dataclass
class Room:
    """A fine-grained cluster of documents, shown as a single room."""

    id: str
    label: str
    wing_id: str  # id of the wing this room belongs to
    centroid_3d: list[float] = field(default_factory=lambda: [0.0] * 3)
    document_ids: list[str] = field(default_factory=list)
    size: int = 0  # member-document count; the viewer sizes the room by it

@dataclass
class Wing:
    """A coarse grouping of related rooms."""

    id: str
    label: str
    centroid_3d: list[float] = field(default_factory=lambda: [0.0] * 3)
    room_ids: list[str] = field(default_factory=list)
    size: int = 0  # total document count across the wing's rooms

@dataclass
class Connection:
    """A link between two nodes (rooms or wings) of the palace graph."""

    source_id: str  # id of a room or wing
    target_id: str  # id of a room or wing
    kind: str  # "hall" stays inside one wing, "tunnel" crosses between wings
    affinity: float = 0.0  # cosine similarity of the two centroids

@dataclass
class Topology:
    """The complete palace graph as four flat lists, ready for serialization."""

    wings: list[Wing] = field(default_factory=list)
    rooms: list[Room] = field(default_factory=list)
    documents: list[Document] = field(default_factory=list)
    connections: list[Connection] = field(default_factory=list)


# ---------------------------------------------------------------------------
# Stage 1: Extract embeddings from ChromaDB
# ---------------------------------------------------------------------------

def load_collection(palace_path: str, collection_name: str | None):
    """Pull every id, embedding, document and metadata from a ChromaDB collection.

    Args:
        palace_path: Directory of a ChromaDB ``PersistentClient``.
        collection_name: Collection to read. When ``None`` and the database
            holds exactly one collection, that one is used. When it holds zero
            or several, a message is printed and the process exits with
            status 1.

    Returns:
        ``(ids, embeddings, documents, metadatas)``:

        - ``embeddings`` is a float32 array with one row per id.
        - ``documents`` has ``""`` for entries stored without text.
        - ``metadatas`` has ``{}`` for entries stored without metadata.
        - An empty collection yields empty lists and a ``(0, 0)`` array.
    """
    client = chromadb.PersistentClient(path=palace_path)

    if collection_name is None:
        collections = client.list_collections()
        if not collections:
            print("  No collections found in database!")
            sys.exit(1)
        # NOTE(review): some chromadb releases return collection *names* from
        # list_collections() rather than Collection objects; accept both.
        names = [getattr(c, "name", c) for c in collections]
        if len(names) == 1:
            collection_name = names[0]
            print(f"  Auto-detected collection: '{collection_name}'")
        else:
            print("  Multiple collections found:")
            for name in names:
                print(f"    - {name} ({client.get_collection(name).count()} docs)")
            print("  Use --collection to pick one.")
            sys.exit(1)

    collection = client.get_collection(collection_name)

    # ChromaDB caps get() at internal limits; page through if needed
    total = collection.count()
    print(f"  Collection '{collection_name}' has {total} documents")

    batch_size = 5000
    all_ids, all_embeddings, all_documents, all_metadatas = [], [], [], []

    for offset in range(0, total, batch_size):
        batch = collection.get(
            include=["embeddings", "metadatas", "documents"],
            limit=batch_size,
            offset=offset,
        )
        batch_ids = batch["ids"]
        if not batch_ids:
            break  # collection shrank while paging; nothing more to read
        n = len(batch_ids)
        docs = batch["documents"] or [None] * n
        metas = batch["metadatas"] or [None] * n
        all_ids.extend(batch_ids)
        all_embeddings.extend(batch["embeddings"])
        # Individual entries can be None; normalize so downstream string
        # slicing (labels, previews) never sees None.
        all_documents.extend(d if d is not None else "" for d in docs)
        all_metadatas.extend(m if m is not None else {} for m in metas)

    if not all_ids:
        # np.array([]) is 1-D, so shape[1] below would raise; return an
        # explicit empty result and let the caller decide what to do.
        print("  Loaded 0 embeddings")
        return [], np.empty((0, 0), dtype=np.float32), [], []

    embeddings = np.array(all_embeddings, dtype=np.float32)
    print(f"  Loaded {len(all_ids)} embeddings, dim={embeddings.shape[1]}")
    return all_ids, embeddings, all_documents, all_metadatas


# ---------------------------------------------------------------------------
# Stage 2: Hierarchical clustering
# ---------------------------------------------------------------------------

def cluster_hierarchical(embeddings: np.ndarray, min_cluster_size: int = 15,
                         target_wings: int | None = None):
    """
    Two-level clustering:
      - Fine (rooms): HDBSCAN on UMAP-reduced embeddings
      - Coarse (wings): Agglomerative on room centroids
    Noise points get assigned to their nearest cluster. If HDBSCAN finds no
    cluster at all, every point is placed in a single room (id 0).

    Args:
        embeddings: One embedding per row.
        min_cluster_size: HDBSCAN ``min_cluster_size``.
        target_wings: Desired wing count. Defaults to sqrt(#rooms) clamped to
            5..30, and is never more than the number of rooms.

    Returns:
        ``(fine_labels, coarse_map)``:

        - ``fine_labels`` is a per-row room-id array with no -1 left.
        - ``coarse_map`` maps room id -> wing id.
    """
    from sklearn.cluster import AgglomerativeClustering
    from sklearn.metrics import pairwise_distances

    # Reduce dims before clustering — HDBSCAN works much better in ~20d
    print("  Reducing to 20d for clustering...")
    reducer = umap.UMAP(
        n_components=20,
        metric="cosine",
        # Clamp for small collections, as project_3d does.
        n_neighbors=min(30, len(embeddings) - 1),
        min_dist=0.0,
        random_state=42,
    )
    reduced = reducer.fit_transform(embeddings)

    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=3,
        metric="euclidean",
        cluster_selection_method="eom",
    )
    fine_labels = clusterer.fit_predict(reduced)
    n_fine = len(set(fine_labels)) - (1 if -1 in fine_labels else 0)
    n_noise = int(np.sum(fine_labels == -1))
    print(f"  Fine clusters (rooms): {n_fine}, noise points: {n_noise}")

    if n_fine == 0:
        # HDBSCAN labeled everything noise (e.g. fewer documents than
        # min_cluster_size). Without a fallback there would be zero rooms and
        # the topology stage would fail; use one room holding everything.
        fine_labels = np.zeros_like(fine_labels)
        print("  No dense clusters found; placing all documents in one room")
    elif n_noise > 0:
        # Assign noise to nearest cluster centroid (in the reduced space)
        cluster_ids = sorted(set(fine_labels) - {-1})
        centroids = np.array([reduced[fine_labels == c].mean(axis=0) for c in cluster_ids])
        noise_mask = fine_labels == -1
        dists = pairwise_distances(reduced[noise_mask], centroids, metric="euclidean")
        nearest = dists.argmin(axis=1)
        fine_labels[noise_mask] = np.array([cluster_ids[n] for n in nearest])
        print(f"  Reassigned {n_noise} noise points to nearest room")

    # Coarse: agglomerative on room centroids with a reasonable wing count
    fine_cluster_ids = sorted(set(fine_labels) - {-1})

    if len(fine_cluster_ids) <= 3:
        coarse_map = {fc: 0 for fc in fine_cluster_ids}
        print("  Coarse clusters (wings): 1")
    else:
        # Wings are grouped on original-space centroids, not the UMAP space.
        centroids = np.array([
            embeddings[fine_labels == c].mean(axis=0) for c in fine_cluster_ids
        ])
        # Target: sqrt(n_rooms) wings, clamped to 5..30
        n_wings = target_wings or max(5, min(30, int(np.sqrt(len(fine_cluster_ids)))))
        n_wings = min(n_wings, len(fine_cluster_ids))
        agg = AgglomerativeClustering(
            n_clusters=n_wings, metric="cosine", linkage="average"
        )
        coarse_labels = agg.fit_predict(centroids)
        coarse_map = {fc: int(coarse_labels[i]) for i, fc in enumerate(fine_cluster_ids)}
        n_coarse = len(set(coarse_map.values()))
        print(f"  Coarse clusters (wings): {n_coarse}")

    return fine_labels, coarse_map


# ---------------------------------------------------------------------------
# Stage 3: UMAP 3D projection
# ---------------------------------------------------------------------------

def project_3d(embeddings: np.ndarray, seed: int = 42) -> np.ndarray:
    """Project high-dim embeddings to 3D world-space coordinates.

    Each axis is min-max scaled to [0, 100]. An axis with (near) zero spread
    is placed at the midpoint, 50, so every coordinate lies in [0, 100].

    Args:
        embeddings: One embedding per row.
        seed: UMAP ``random_state`` for reproducible layouts.

    Returns:
        An array of shape ``(len(embeddings), 3)``.
    """
    reducer = umap.UMAP(
        n_components=3,
        metric="cosine",
        n_neighbors=min(15, len(embeddings) - 1),
        min_dist=0.1,
        random_state=seed,
    )
    coords = reducer.fit_transform(embeddings)
    # Normalize to a reasonable world-space range (0..100)
    for dim in range(3):
        lo, hi = coords[:, dim].min(), coords[:, dim].max()
        if hi - lo > 1e-6:
            coords[:, dim] = (coords[:, dim] - lo) / (hi - lo) * 100.0
        else:
            # Degenerate axis: without this it would keep raw UMAP values,
            # which can lie outside the advertised 0..100 range.
            coords[:, dim] = 50.0
    print("  Projected to 3D, range [0, 100]")
    return coords


# ---------------------------------------------------------------------------
# Stage 4: LLM labeling
# ---------------------------------------------------------------------------

def label_clusters_llm(
    cluster_ids: list[int],
    embeddings: np.ndarray,
    labels: np.ndarray,
    documents: list[str],
    level: str = "room",
    sample_k: int = 8,
) -> dict[int, str]:
    """Ask Claude to name each cluster from representative samples.

    All clusters are batched into a single request. For each cluster, the
    first ``sample_k`` member documents (in ``documents`` order) are sent,
    trimmed to 200 characters.

    Falls back to ``label_clusters_keyword`` when the Anthropic client cannot
    be created, or when the request or JSON parsing fails. Clusters the reply
    leaves out are also given keyword labels.

    Args:
        cluster_ids: Cluster ids to label.
        embeddings: Not used; kept for interface compatibility.
        labels: Per-document cluster id, aligned with ``documents``.
        documents: Document texts; ``None`` entries are treated as empty.
        level: Noun used in the prompt (e.g. "room" or "wing").
        sample_k: Maximum number of sample documents per cluster.

    Returns:
        Mapping of cluster id -> label string.
    """
    try:
        import anthropic
        client = anthropic.Anthropic()  # uses ANTHROPIC_API_KEY env
    except Exception as e:
        # Covers a missing package as well as client-construction errors.
        print(f"  ⚠ Anthropic client unavailable ({e}) — falling back to keyword labels")
        return label_clusters_keyword(cluster_ids, labels, documents)

    # Batch all clusters into one prompt to save API calls
    prompt_parts = [
        f"You are labeling {level}s in a knowledge palace. "
        f"Each cluster below contains sample documents. "
        f"Give each cluster a short, evocative label (2-5 words). "
        f"Respond ONLY with a JSON object mapping cluster ID to label string.\n\n"
    ]
    for cid in cluster_ids:
        member_idx = np.flatnonzero(labels == cid)[:sample_k]
        # `or ''` guards against None texts, which would crash the slice.
        previews = "\n".join(f"  - {(documents[i] or '')[:200]}" for i in member_idx)
        prompt_parts.append(f"Cluster {cid}:\n{previews}\n\n")

    try:
        resp = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=1024,
            messages=[{"role": "user", "content": "".join(prompt_parts)}],
        )
        text = resp.content[0].text.strip()
        # Strip markdown fences if present
        if text.startswith("```"):
            text = text.split("\n", 1)[1].rsplit("```", 1)[0]
        parsed = json.loads(text)
        result = {int(k): v for k, v in parsed.items()}
        print(f"  Labeled {len(result)} {level}s via Claude")
    except Exception as e:
        print(f"  ⚠ LLM labeling failed ({e}) — falling back to keywords")
        return label_clusters_keyword(cluster_ids, labels, documents)

    # Give clusters the model skipped a keyword label, not a generic name.
    missing = [cid for cid in cluster_ids if cid not in result]
    if missing:
        result.update(label_clusters_keyword(missing, labels, documents))

    return result


def label_clusters_keyword(
    cluster_ids: list[int],
    labels: np.ndarray,
    documents: list[str],
) -> dict[int, str]:
    """Fallback labeler: the three most frequent non-stopword words per cluster.

    This is plain term frequency, with no inverse-document weighting. It
    counts words of 3+ ASCII letters, lower-cased, taken from the first 500
    characters of each member document.

    Args:
        cluster_ids: Cluster ids to label.
        labels: Per-document cluster id, aligned with ``documents``.
        documents: Document texts; ``None`` entries are treated as empty.

    Returns:
        Mapping of cluster id -> space-joined keywords. A cluster with no
        usable words gets ``cluster-<id>``.
    """
    from collections import Counter
    import re

    word_re = re.compile(r"[a-zA-Z]{3,}")
    stopwords = {
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "could",
        "should", "may", "might", "shall", "can", "to", "of", "in", "for",
        "on", "with", "at", "by", "from", "as", "into", "through", "during",
        "before", "after", "above", "below", "between", "and", "but", "or",
        "not", "no", "nor", "so", "yet", "both", "either", "neither", "it",
        "its", "this", "that", "these", "those", "i", "me", "my", "we", "our",
        "you", "your", "he", "him", "his", "she", "her", "they", "them", "their",
        "what", "which", "who", "whom", "how", "when", "where", "why", "if",
        "then", "than", "also", "just", "about", "up", "out", "all", "more",
    }
    result = {}
    for cid in cluster_ids:
        text = " ".join((documents[i] or "")[:500] for i in np.flatnonzero(labels == cid))
        words = [w for w in word_re.findall(text.lower()) if w not in stopwords]
        top = [w for w, _ in Counter(words).most_common(3)]
        result[cid] = " ".join(top) if top else f"cluster-{cid}"
    print(f"  Keyword-labeled {len(result)} clusters")
    return result


# ---------------------------------------------------------------------------
# Stage 5: Build topology
# ---------------------------------------------------------------------------

def build_topology(
    ids: list[str],
    embeddings: np.ndarray,
    documents: list[str],
    metadatas: list[dict],
    fine_labels: np.ndarray,
    coarse_map: dict[int, int],
    coords_3d: np.ndarray,
    room_labels: dict[int, str],
    wing_labels: dict[int, str],
) -> Topology:
    """Assemble wings, rooms, documents and connections into a Topology.

    Args:
        ids, documents, metadatas: Per-document fields, aligned by index.
        embeddings: Original embeddings, one row per document. Room-to-room
            affinity is measured on these, not on the 3D projection.
        fine_labels: Per-document room id; -1 means "not in any room".
        coarse_map: Room id -> wing id.
        coords_3d: Per-document 3D position, one row per document.
        room_labels, wing_labels: Display names. Missing ids fall back to
            ``room-<id>`` / ``wing-<id>``.

    Returns:
        The populated Topology. Each room proposes links to its (up to) 5
        most similar rooms, and a link is kept when cosine similarity is at
        least 0.15. Each pair is linked at most once: a "hall" within a wing,
        or a "tunnel" across wings. Zero or one room yields no connections.
    """
    topo = Topology()
    room_map: dict[int, Room] = {}
    wing_map: dict[int, Wing] = {}

    # Build wings first so rooms can register with their parent.
    for wing_id in sorted(set(coarse_map.values())):
        label = wing_labels.get(wing_id, f"wing-{wing_id}")
        wing = Wing(id=f"wing_{wing_id}", label=label)
        wing_map[wing_id] = wing
        topo.wings.append(wing)

    # Build rooms, collecting each wing's room centroids as we go.
    fine_cluster_ids = sorted(set(fine_labels) - {-1})
    wing_room_centroids: dict[int, list[list[float]]] = {wid: [] for wid in wing_map}
    for fc in fine_cluster_ids:
        wid = coarse_map.get(fc, 0)
        mask = fine_labels == fc
        room = Room(
            id=f"room_{fc}",
            label=room_labels.get(fc, f"room-{fc}"),
            wing_id=f"wing_{wid}",
            centroid_3d=coords_3d[mask].mean(axis=0).tolist(),
            size=int(mask.sum()),
        )
        room_map[fc] = room
        topo.rooms.append(room)
        wing_map[wid].room_ids.append(room.id)
        wing_map[wid].size += room.size
        wing_room_centroids[wid].append(room.centroid_3d)

    # A wing's centroid is the mean of its rooms' 3D centroids.
    for wid, centroids in wing_room_centroids.items():
        if centroids:
            wing_map[wid].centroid_3d = np.mean(centroids, axis=0).tolist()

    # Build documents
    for i, doc_id in enumerate(ids):
        fc = fine_labels[i]
        doc = Document(
            id=doc_id,
            content_preview=documents[i][:200] if documents[i] else "",
            metadata=metadatas[i] or {},
            position_3d=coords_3d[i].tolist(),
            room_id=f"room_{fc}" if fc != -1 else None,
        )
        topo.documents.append(doc)
        if fc != -1 and fc in room_map:
            room_map[fc].document_ids.append(doc_id)

    # Build connections using original embedding centroids (not 3D projections)
    # Top-K nearest rooms per room, split into halls (same wing) and tunnels (cross-wing)
    room_list = [room_map[fc] for fc in fine_cluster_ids]
    top_k = min(5, len(room_list) - 1)
    # top_k <= 0 means fewer than two rooms. Skipping avoids
    # cosine_similarity on an empty array, and the `[-0:]` slice that would
    # select the whole row.
    if top_k > 0:
        room_emb_centroids = np.array([
            embeddings[fine_labels == fc].mean(axis=0)
            for fc in fine_cluster_ids
        ])
        sims = cosine_similarity(room_emb_centroids)
        np.fill_diagonal(sims, 0)  # no self-connections

        seen: set[tuple[int, int]] = set()
        for i, room in enumerate(room_list):
            for j in np.argsort(sims[i])[-top_k:]:
                j = int(j)
                if sims[i, j] < 0.15:
                    continue
                edge = (min(i, j), max(i, j))  # undirected: keep one per pair
                if edge in seen:
                    continue
                seen.add(edge)
                other = room_list[j]
                topo.connections.append(Connection(
                    source_id=room.id,
                    target_id=other.id,
                    kind="hall" if room.wing_id == other.wing_id else "tunnel",
                    affinity=float(sims[i, j]),
                ))

    print(f"  Topology: {len(topo.wings)} wings, {len(topo.rooms)} rooms, "
          f"{len(topo.documents)} docs, {len(topo.connections)} connections")
    return topo


# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------

def discover(
    palace_path: str,
    collection_name: str | None = None,
    min_cluster_size: int = 15,
    target_wings: int | None = None,
    use_llm: bool = True,
) -> Topology:
    """Run the full pipeline: load, cluster, project, label, assemble.

    Exits the process with status 1 when the collection holds fewer than 4
    documents.

    Args:
        palace_path: ChromaDB persistence directory.
        collection_name: Collection to read (auto-detected when ``None``).
        min_cluster_size: HDBSCAN minimum room size.
        target_wings: Desired wing count (``None`` = automatic).
        use_llm: Label rooms and wings with Claude. When false, rooms get
            keyword labels and wings get ``wing-<id>``.

    Returns:
        The assembled Topology.
    """
    print("Stage 1: Loading embeddings...")
    ids, embeddings, documents, metadatas = load_collection(palace_path, collection_name)

    if len(ids) < 4:
        print("Too few documents for meaningful clustering.")
        sys.exit(1)

    print("Stage 2: Clustering...")
    fine_labels, coarse_map = cluster_hierarchical(embeddings, min_cluster_size, target_wings)

    print("Stage 3: 3D projection...")
    coords_3d = project_3d(embeddings)

    print("Stage 4: Labeling...")
    fine_cluster_ids = sorted(set(fine_labels) - {-1})
    coarse_cluster_ids = sorted(set(coarse_map.values()))

    if use_llm:
        room_labels = label_clusters_llm(
            fine_cluster_ids, embeddings, fine_labels, documents, level="room"
        )
        # Wing labels need documents grouped by *wing*. Passing fine_labels
        # here would label wing N from the docs of room N. Build a synthetic
        # sample that takes up to 4 docs from every room of the wing,
        # interleaved (1st doc of each room, then 2nd, ...). That way the
        # first few samples the labeler uses span as many rooms as possible.
        room_members = {
            fc: np.flatnonzero(fine_labels == fc)[:4] for fc in fine_cluster_ids
        }
        wing_sample_docs: list[str] = []
        wing_sample_labels: list[int] = []
        for depth in range(4):
            for fc, members in room_members.items():
                if depth < len(members):
                    wing_sample_docs.append(documents[members[depth]] or "")
                    wing_sample_labels.append(coarse_map.get(fc, 0))
        wing_labels = label_clusters_llm(
            coarse_cluster_ids,
            embeddings,
            np.array(wing_sample_labels),
            wing_sample_docs,
            level="wing",
        )
    else:
        room_labels = label_clusters_keyword(fine_cluster_ids, fine_labels, documents)
        wing_labels = {wid: f"wing-{wid}" for wid in coarse_cluster_ids}

    print("Stage 5: Building topology...")
    topo = build_topology(
        ids, embeddings, documents, metadatas,
        fine_labels, coarse_map, coords_3d,
        room_labels, wing_labels,
    )
    return topo


def main():
    """Command-line entry point: parse flags, run discovery, write the JSON file."""
    parser = argparse.ArgumentParser(description="Discover palace structure from vectors")
    add = parser.add_argument
    add("--palace", required=True, help="Path to ChromaDB persistence dir")
    add("--collection", default=None, help="Collection name (auto-detects if only one)")
    add("--out", default="topology.json", help="Output topology file")
    add("--min-cluster-size", type=int, default=15)
    add("--target-wings", type=int, default=None, help="Target number of wings (default: auto)")
    add("--no-llm", action="store_true", help="Skip LLM labeling, use keywords")
    args = parser.parse_args()

    topology = discover(
        palace_path=args.palace,
        collection_name=args.collection,
        min_cluster_size=args.min_cluster_size,
        target_wings=args.target_wings,
        use_llm=not args.no_llm,
    )

    # asdict() recurses into the nested dataclass lists, giving the same
    # {"wings", "rooms", "documents", "connections"} mapping in field order.
    payload = asdict(topology)
    out_path = Path(args.out)
    out_path.write_text(json.dumps(payload, indent=2))
    print(f"\n✓ Wrote {args.out}")


if __name__ == "__main__":
    main()
