import os
import fitz  # PyMuPDF
import faiss
import numpy as np
from typing import List
from sentence_transformers import SentenceTransformer

# Step 1: Load and split PDF text
def load_pdfs_from_folder(folder_path: str) -> List[str]:
    """Extract the text of every PDF directly inside a folder and chunk it.

    Each file whose name ends in ``.pdf`` (case-insensitive) is opened with
    PyMuPDF, the text of all its pages is concatenated, and the result is
    split with ``chunk_text`` using its default settings.

    Args:
        folder_path: Directory to scan. Sub-directories are not descended into.

    Returns:
        The chunks of all PDFs in one flat list, ordered by file name and
        then by position within each file. Empty if no PDF is found.

    Raises:
        OSError: If ``folder_path`` cannot be listed.
    """
    all_chunks: List[str] = []
    # os.listdir returns entries in arbitrary order; sorting keeps the chunk
    # order (and therefore the index positions) reproducible between runs.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(".pdf"):
            continue
        full_path = os.path.join(folder_path, filename)
        print(f"Processing: {filename}")
        # The context manager closes the document even if extraction fails,
        # so file handles are not leaked across a large folder.
        with fitz.open(full_path) as doc:
            text = "".join(page.get_text() for page in doc)
        all_chunks.extend(chunk_text(text))
    return all_chunks

# Step 2: Split text into chunks
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
    """Split text into overlapping chunks of whitespace-separated words.

    Consecutive chunks start ``chunk_size - overlap`` words apart, so each
    chunk repeats the last ``overlap`` words of the one before it.

    Args:
        text: Text to split; words are delimited by any whitespace.
        chunk_size: Maximum number of words per chunk. Must be positive.
        overlap: Number of words shared by neighbouring chunks. Must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        The chunks, each with its words joined by single spaces. Empty when
        ``text`` contains no words.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        # overlap >= chunk_size would give a zero or negative stride, and a
        # negative overlap would silently drop words between chunks.
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )

    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            # This chunk already reaches the last word; any later window
            # would only repeat a tail that is fully contained in it.
            break
    return chunks

# Step 3: Embed text chunks
def embed_chunks(chunks: List[str], model) -> np.ndarray:
    """Encode the given chunks with ``model`` and return its output.

    Args:
        chunks: Texts to embed, passed to the model in a single call.
        model: Object exposing ``encode(texts, show_progress_bar=...,
            convert_to_numpy=...)``; this script passes a
            ``SentenceTransformer``.

    Returns:
        Whatever ``model.encode`` returns when asked for NumPy output with
        the progress bar enabled.
    """
    encode_options = {"show_progress_bar": True, "convert_to_numpy": True}
    return model.encode(chunks, **encode_options)

# Step 4: Build FAISS index
def build_faiss_index(embeddings: np.ndarray):
    """Build an exact L2 FAISS index over a matrix of embeddings.

    Args:
        embeddings: Array with one embedding per row. It is converted to a
            C-contiguous ``float32`` array before being added, which is the
            layout FAISS documents for its Python interface.

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of ``embeddings``, in
        the same order.

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional, e.g. the empty
            1-D array produced when there was nothing to embed.
    """
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        # Fail with a clear message instead of an IndexError on shape[1].
        raise ValueError(
            "embeddings must be a 2-D array of shape (n_vectors, dim), "
            f"got shape {vectors.shape}"
        )
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index

# Step 5: Search top-k similar chunks
def search(query: str, model, index, chunks: List[str], k: int = 3):
    """Return the chunks whose embeddings are closest to the query.

    Args:
        query: Question text to embed and look up.
        model: Object whose ``encode([query], convert_to_numpy=True)`` yields
            the query embedding.
        index: Object whose ``search(embedding, k)`` returns a
            ``(distances, indices)`` pair, such as a FAISS index.
        chunks: Texts in the same order as the vectors stored in ``index``.
        k: Maximum number of results requested from the index.

    Returns:
        Up to ``k`` chunks in the order reported by the index. Fewer are
        returned when the index cannot supply ``k`` valid neighbours.
    """
    query_embedding = model.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, k)
    # FAISS pads missing neighbours with -1 when it holds fewer than k
    # vectors; indexing chunks with -1 would wrongly return the last chunk.
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]

# Load embedding model
print("🔄 Loading model...")
model = SentenceTransformer("BAAI/bge-small-en")

# Load documents, embed, and index
print("📂 Reading PDFs and building vector index...")
chunks = load_pdfs_from_folder("./documents")
if not chunks:
    # With no text there is nothing to embed or index; stop with a clear
    # message rather than failing later inside the embedding/index steps.
    raise SystemExit("No text could be extracted from PDFs in ./documents")
embeddings = embed_chunks(chunks, model)
index = build_faiss_index(embeddings)

# Chat loop: answer questions until the user types 'exit'/'quit' or the
# input stream ends.
print("\n🤖 School Chatbot Ready! Ask a question, or type 'exit' to quit.\n")
while True:
    try:
        # Strip so that "exit " or " quit" are still recognised as commands.
        query = input("You: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D, Ctrl-C or a closed stdin ends the session without a traceback.
        print("\n👋 Exiting. Goodbye!")
        break
    if query.lower() in ("exit", "quit"):
        print("👋 Exiting. Goodbye!")
        break
    if not query:
        # Nothing to look up; prompt again instead of searching for "".
        continue
    results = search(query, model, index, chunks, k=3)
    print("\n📄 Top Results:\n")
    for i, res in enumerate(results, 1):
        print(f"{i}. {res}\n")
