import os
import glob
from typing import List
import pypdf
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
import json
import ollama  # 🧠 For local LLM via Ollama

# Load environment variables (python-dotenv) before any model or client is set up.
load_dotenv()

# Module-level embeddings model. It is created once at import time and shared by
# create_vectorstore() (building the index) and load_or_create_vectorstore()
# (re-opening a saved index), so both sides use the same embedding model.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': 'cpu'}  # run the embedding model on the CPU
)

# Directory that load_or_create_vectorstore() passes to load_documents().
# It is a relative path, so it resolves against the current working directory.
DOCUMENTS_DIR = "documents"

def load_documents(directory: str) -> List[Document]:
    """Load Q&A pairs from every ``*.json`` file directly inside *directory*.

    A file may contain either a JSON list of Q&A objects or a JSON object
    whose ``qas`` key holds that list. Each Q&A object provides the question
    under ``question`` (or ``q``) and the answer under ``answer`` (or ``a``).
    Only the question becomes the document's ``page_content``; the answer and
    the source file path are stored in the metadata. Other file types
    (e.g. PDFs) are ignored.

    Problems are reported on stdout and never raised: a file that cannot be
    read or parsed, or whose payload is not a list, is skipped as a whole;
    an individual entry that is not an object, has a non-string question, or
    lacks a question or an answer is skipped on its own so the remaining
    entries of the file are still loaded.

    Args:
        directory: Folder to scan (non-recursively) for ``*.json`` files.

    Returns:
        One ``Document`` per usable Q&A pair, in sorted file order.
    """
    documents = []
    # glob order depends on the filesystem; sorting keeps the resulting
    # document order (and therefore the built index) reproducible.
    for json_path in sorted(glob.glob(os.path.join(directory, "*.json"))):
        print(f"Processing {json_path}...")
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"❌ Error processing {json_path}: {str(e)}")
            continue

        qas = data['qas'] if isinstance(data, dict) and 'qas' in data else data
        if not isinstance(qas, list):
            print(f"❌ Error processing {json_path}: expected a list of Q&A objects")
            continue

        malformed = 0
        for qa in qas:
            if not isinstance(qa, dict):
                malformed += 1
                continue
            question = qa.get('question') or qa.get('q')
            answer = qa.get('answer') or qa.get('a')
            if not question or not answer:
                # Incomplete pair: nothing useful to index.
                continue
            if not isinstance(question, str):
                # page_content must be text; a non-string question is unusable.
                malformed += 1
                continue
            documents.append(Document(
                page_content=question,
                metadata={
                    "answer": answer,
                    "source": json_path
                }
            ))

        if malformed:
            print(f"⚠️ Skipped {malformed} malformed entries in {json_path}")
        print(f"✅ Successfully processed {json_path}")
    return documents

def create_vectorstore(documents: List[Document]) -> FAISS:
    """Build a FAISS vectorstore from *documents* and persist it to disk.

    The documents are split with a ``RecursiveCharacterTextSplitter``
    (``chunk_size=500``, ``chunk_overlap=50``), embedded with the module-level
    ``embeddings`` model, and the resulting index is saved under
    ``vectorstore/db`` (relative to the current working directory).

    Args:
        documents: Documents to index; must not be empty.

    Returns:
        The newly built vectorstore.

    Raises:
        ValueError: If *documents* is empty.
    """
    if not documents:
        # Fail early with a clear message: an index cannot be built from zero
        # documents, and the failure inside the FAISS wrapper is opaque.
        raise ValueError("Cannot create a vectorstore: no documents were loaded")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50
    )
    texts = text_splitter.split_documents(documents)
    vectorstore = FAISS.from_documents(texts, embeddings)
    os.makedirs("vectorstore", exist_ok=True)
    vectorstore.save_local("vectorstore/db")
    return vectorstore

def load_or_create_vectorstore() -> FAISS:
    """Return the persisted vectorstore, building it first when none is saved.

    If ``vectorstore/db`` exists it is loaded with the module-level
    ``embeddings`` model. Otherwise the Q&A documents under ``DOCUMENTS_DIR``
    are loaded and handed to ``create_vectorstore``.

    Returns:
        The loaded or freshly created vectorstore.
    """
    index_path = "vectorstore/db"

    if not os.path.exists(index_path):
        print("Creating new vectorstore...")
        return create_vectorstore(load_documents(DOCUMENTS_DIR))

    print("Loading existing vectorstore...")
    # NOTE(review): this flag opts in to deserialization that LangChain itself
    # labels dangerous; only point this at index files this program wrote.
    return FAISS.load_local(
        index_path,
        embeddings,
        allow_dangerous_deserialization=True,
    )

def get_relevant_context(query: str, vectorstore: FAISS, k: int = 3) -> str:
    """Build the prompt context for *query* from its closest stored Q&A pairs.

    Args:
        query: The user's question.
        vectorstore: Store to search; its ``similarity_search`` is called
            with *query* and *k*.
        k: Number of hits to request (defaults to 3, the previous fixed value).

    Returns:
        The hits rendered as ``Q: <question>\\nA: <answer>`` blocks separated
        by blank lines. A hit without ``answer`` metadata gets an empty
        answer; no hits yields an empty string.
    """
    docs = vectorstore.similarity_search(query, k=k)
    return "\n\n".join(
        f"Q: {doc.page_content}\nA: {doc.metadata.get('answer', '')}" for doc in docs
    )

def generate_response(query: str, context: str, model: str = "tinyllama") -> str:
    """Generate an answer to *query* with a local model served by Ollama.

    The context and the question are embedded in a single user message that
    instructs the model to answer from the context only.

    Args:
        query: The user's question.
        context: Retrieved text the model should base its answer on.
        model: Name of the Ollama model to use (defaults to ``"tinyllama"``,
            the previous fixed value).

    Returns:
        The text content of the model's reply.
    """
    prompt = f"""You are a helpful AI assistant. Use the following context to answer the question. 
If you cannot answer the question based on the context, say so.

Context:
{context}

Question: {query}

Answer:"""

    response = ollama.chat(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )

    return response['message']['content']

def main():
    """Run the interactive console chat loop.

    Loads (or builds) the vectorstore once, then repeatedly reads a question,
    retrieves matching context and prints the generated answer. Typing
    ``quit`` (any letter case) ends the session, as does end-of-input or an
    interrupt at the prompt. Blank input is ignored. Errors raised while
    answering a question are printed and the loop continues.
    """
    # Initialize vectorstore
    vectorstore = load_or_create_vectorstore()

    print("\n🧠 RAG Chatbot (TinyLlama) Initialized! Type 'quit' to exit.")

    while True:
        try:
            query = input("\nYour question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, Ctrl-C or exhausted piped stdin: leave quietly instead
            # of crashing with a traceback.
            print()
            break

        if query.lower() == 'quit':
            break

        if not query:
            continue

        try:
            # Get relevant context
            context = get_relevant_context(query, vectorstore)

            # Generate response
            response = generate_response(query, context)

            print("\nAnswer:", response)

        except Exception as e:
            # Top-level boundary: report the failure and keep the session alive.
            print(f"\n❌ Error: {str(e)}")

# Start the chat loop only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
