import os
import fitz  # PyMuPDF
import faiss
import numpy as np
import pickle
from typing import List
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# === PDF Loading and Chunking ===
def load_pdfs(folder_path: str) -> List[str]:
    """Extract the text of every PDF in ``folder_path`` and split it into chunks.

    Files are matched by a case-insensitive ``.pdf`` extension and processed
    in sorted filename order so the resulting chunk order is reproducible
    between runs (``os.listdir`` order is arbitrary). The text of all pages
    of one file is concatenated and passed to ``split_text``.

    Args:
        folder_path: Directory to scan (non-recursive).

    Returns:
        The chunks of all PDFs, in file order. Empty if no PDF is found.

    Raises:
        OSError: If ``folder_path`` cannot be listed.
    """
    chunks = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(".pdf"):
            continue
        # Open as a context manager so the document handle is released even
        # if text extraction fails part-way through.
        with fitz.open(os.path.join(folder_path, filename)) as doc:
            text = "".join(page.get_text() for page in doc)
        chunks.extend(split_text(text))
    return chunks

def split_text(text: str, chunk_size=300, overlap=50) -> List[str]:
    """Split ``text`` into overlapping chunks of whitespace-separated words.

    Consecutive chunks start ``chunk_size - overlap`` words apart, so each
    chunk shares its first ``overlap`` words with the end of the previous one.
    Splitting stops as soon as a chunk reaches the end of the text, which
    avoids a trailing fragment that is wholly contained in the previous chunk.

    Args:
        text: Text to split; words are separated by any whitespace.
        chunk_size: Maximum number of words per chunk (must be positive).
        overlap: Number of words shared by neighbouring chunks
            (``0 <= overlap < chunk_size``).

    Returns:
        The chunks, each joined with single spaces. Empty for blank text.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        # A non-positive step would either crash range() or yield nothing.
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # This chunk already covers the tail; a further window would only
            # repeat words that are part of it.
            break
    return chunks

# === Embedding and Indexing ===
def embed_chunks(chunks, model) -> np.ndarray:
    """Encode ``chunks`` with ``model`` and return the resulting embeddings.

    ``model`` must expose an ``encode`` method accepting the
    ``show_progress_bar`` and ``convert_to_numpy`` keyword arguments; its
    return value is passed through unchanged.
    """
    encode_options = {
        "show_progress_bar": True,   # embedding a document set can be slow
        "convert_to_numpy": True,    # ask for a NumPy array rather than tensors
    }
    return model.encode(chunks, **encode_options)

def build_index(embeddings: np.ndarray):
    """Build a flat (exact) L2 FAISS index over ``embeddings``.

    Args:
        embeddings: 2-D array with one embedding vector per row.

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional, e.g. because
            no chunks were embedded.
    """
    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2:
        raise ValueError(
            "embeddings must be a 2-D array of shape (n_vectors, dim); "
            f"got shape {embeddings.shape}"
        )
    # FAISS's Python bindings expect C-contiguous float32 data; this is a
    # no-op copy-wise when the array already has that layout.
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index

def save_index(index, chunks):
    """Persist the FAISS index and its chunk list under the ``index`` directory.

    The index is written with ``faiss.write_index`` and ``chunks`` is pickled
    next to it; the directory is created if it does not exist yet.
    """
    index_file = "index/faiss.index"
    chunks_file = "index/chunks.pkl"

    os.makedirs("index", exist_ok=True)
    faiss.write_index(index, index_file)
    with open(chunks_file, "wb") as handle:
        pickle.dump(chunks, handle)

def load_index_and_chunks():
    """Load the saved FAISS index and chunk list from the ``index`` directory.

    Returns:
        ``(index, chunks)`` when both files exist and agree in size, otherwise
        ``(None, None)`` so the caller can rebuild the index from scratch.
    """
    index_path = "index/faiss.index"
    chunks_path = "index/chunks.pkl"
    if not (os.path.exists(index_path) and os.path.exists(chunks_path)):
        return None, None

    print("✅ Found saved FAISS index. Loading...")
    index = faiss.read_index(index_path)
    # SECURITY: unpickling can execute arbitrary code. This file is expected
    # to be one this program wrote itself; never point it at untrusted data.
    with open(chunks_path, "rb") as f:
        chunks = pickle.load(f)

    # Search results are positions into ``chunks``; if the two files went out
    # of sync (e.g. an interrupted save), lookups would be wrong or crash.
    if index.ntotal != len(chunks):
        print("⚠️ Saved index and chunks are out of sync. Ignoring saved index.")
        return None, None
    return index, chunks

# === Query Retrieval ===
def retrieve(query, embedder, index, chunks, k=3):
    """Return up to ``k`` chunks whose embeddings are nearest to ``query``.

    Args:
        query: The user's question.
        embedder: Object whose ``encode`` method turns a list of strings into
            an array of embeddings.
        index: Index whose ``search(vectors, k)`` returns ``(distances, ids)``.
        chunks: Sequence of texts; position ``i`` corresponds to index id ``i``.
        k: Maximum number of chunks to return.

    Returns:
        The matching chunks, best match first. Fewer than ``k`` items are
        returned when the index holds fewer than ``k`` vectors.
    """
    query_vector = embedder.encode([query], convert_to_numpy=True)
    _, neighbor_ids = index.search(query_vector, k)
    # FAISS pads missing neighbours with -1, which as a Python index would
    # silently select the *last* chunk; drop anything outside the valid range.
    return [chunks[i] for i in neighbor_ids[0] if 0 <= i < len(chunks)]

# === Phi-3 Reasoning Model ===
def load_phi3_model():
    """Load the Phi-3 mini instruct tokenizer and causal language model.

    Returns:
        A ``(tokenizer, model)`` tuple. Weights are requested in float16 when
        CUDA is available and float32 otherwise; device placement is left to
        ``device_map="auto"``.
    """
    checkpoint = "microsoft/phi-3-mini-4k-instruct"

    # NOTE(review): trust_remote_code is enabled for the tokenizer only, not
    # for the model below — confirm whether that asymmetry is intentional.
    tok = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    if torch.cuda.is_available():
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    lm = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=weights_dtype,
        device_map="auto",
    )
    return tok, lm

def generate_answer(query, context_chunks, tokenizer, model):
    """Generate an answer to ``query`` grounded in ``context_chunks``.

    The chunks are joined into a context section of a plain-text prompt, the
    model samples a continuation, and only the newly generated tokens are
    decoded and returned.

    Args:
        query: The user's question.
        context_chunks: Retrieved text passages to answer from.
        tokenizer: Tokenizer matching ``model``.
        model: Causal language model exposing ``generate`` and ``device``.

    Returns:
        The generated answer text with surrounding whitespace stripped.
    """
    context = "\n\n".join(context_chunks)
    prompt = f"""Answer the question based only on the context below.

### Context:
{context}

### Question:
{query}

### Answer:"""

    # NOTE(review): truncation=True cuts an over-long prompt at the tokenizer's
    # configured limit; presumably that removes the end of the prompt (question
    # and answer marker) — confirm the tokenizer's truncation side.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        # temperature/top_p only take effect when sampling is enabled;
        # without do_sample=True generation is greedy and they are ignored.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # The returned sequence starts with the prompt tokens. Slice them off
    # instead of splitting the decoded text on "### Answer:", which breaks
    # when that marker also occurs in the context or in the model's output.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = output[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# === Main Program ===
if __name__ == "__main__":
    # Interactive entry point: make sure an index exists (load or build),
    # load the language model, then answer questions until the user leaves.
    print("🔎 Loading BGE-small embedding model...")
    embedder = SentenceTransformer("BAAI/bge-small-en")

    index, chunks = load_index_and_chunks()
    if index is None:
        print("📚 No index found. Processing PDFs...")
        chunks = load_pdfs("./documents")
        if not chunks:
            # Without any text there is nothing to embed or search; stop with
            # a clear message instead of failing inside the indexing code.
            raise SystemExit("❌ No PDF text found in ./documents. Nothing to index.")
        embeddings = embed_chunks(chunks, embedder)
        index = build_index(embeddings)
        save_index(index, chunks)
    else:
        print("✅ FAISS index ready.")

    print("🧠 Loading Phi-3 reasoning model...")
    tokenizer, phi3 = load_phi3_model()

    print("\n🤖 School Chatbot Ready! Ask a question or type 'exit'.\n")
    while True:
        try:
            query = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C (or a closed stdin) ends the chat cleanly.
            print("\n👋 Goodbye!")
            break
        if not query:
            # Ignore blank lines rather than running retrieval on nothing.
            continue
        if query.lower() in ("exit", "quit"):
            print("👋 Goodbye!")
            break
        top_chunks = retrieve(query, embedder, index, chunks)
        answer = generate_answer(query, top_chunks, tokenizer, phi3)
        print(f"\n🤖 {answer}\n")
