import os
import glob
from typing import List
import pypdf
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from groq import Groq
from dotenv import load_dotenv

# Folder that is scanned for the source PDF files.
DOCUMENTS_DIR = "documents"

# Populate the process environment (via python-dotenv) before any key is read.
load_dotenv()

# Groq API client; the key is taken from the GROQ_API_KEY environment variable.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Embedding model shared by index creation, index loading and querying.
# The device is pinned to the CPU.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
)

def load_pdfs(directory: str) -> List[Document]:
    """Load the text of every PDF page found directly inside ``directory``.

    Each page whose extracted text is not blank becomes one ``Document``
    whose metadata records the source path and the 1-based page number.
    Files are visited in sorted path order so that repeated runs yield the
    same document sequence. A file that raises while being read is reported
    and skipped; pages already extracted from it before the failure are kept.

    Args:
        directory: Folder searched (non-recursively) for ``*.pdf`` files.

    Returns:
        The extracted page documents. The list is empty when the folder is
        missing, holds no PDFs, or none of the pages produced any text.
    """
    # glob returns matches in arbitrary order; sort for reproducible output.
    pdf_files = sorted(glob.glob(os.path.join(directory, "*.pdf")))
    if not pdf_files:
        print(f"⚠️ No PDF files found in {directory!r}")

    documents = []
    for pdf_path in pdf_files:
        try:
            print(f"Processing {pdf_path}...")
            pdf_reader = pypdf.PdfReader(pdf_path)
            for page_number, page in enumerate(pdf_reader.pages, start=1):
                # extract_text() may yield nothing for image-only pages.
                text = page.extract_text() or ""
                if text.strip():
                    documents.append(Document(
                        page_content=text,
                        metadata={"source": pdf_path, "page": page_number}
                    ))
            print(f"✅ Successfully processed {pdf_path}")
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the
            # whole ingestion run.
            print(f"❌ Error processing {pdf_path}: {str(e)}")

    return documents

def create_vectorstore(documents: List[Document]) -> FAISS:
    """Build a FAISS vectorstore from ``documents`` and persist it to disk.

    The documents are split into chunks of at most 500 characters with a
    50-character overlap, embedded with the module-level ``embeddings``
    model, indexed, and saved under ``vectorstore/db``.

    Args:
        documents: Page documents to index; must not be empty.

    Returns:
        The newly created vectorstore.

    Raises:
        ValueError: If ``documents`` is empty or splitting produces no
            chunks. Without this check the failure would surface from
            inside the FAISS index construction with an unhelpful message.
    """
    if not documents:
        raise ValueError(
            "Cannot create a vectorstore: no documents were provided "
            "(were any PDFs with extractable text found?)."
        )

    # Split documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50
    )
    texts = text_splitter.split_documents(documents)
    if not texts:
        raise ValueError(
            "Cannot create a vectorstore: splitting the documents produced no chunks."
        )

    # Create vectorstore
    vectorstore = FAISS.from_documents(texts, embeddings)

    # Save vectorstore
    os.makedirs("vectorstore", exist_ok=True)
    vectorstore.save_local("vectorstore/db")

    return vectorstore

def load_or_create_vectorstore() -> FAISS:
    """Return the persisted vectorstore, building it from the PDFs if absent.

    When ``vectorstore/db`` does not exist, the PDFs in ``DOCUMENTS_DIR`` are
    loaded and indexed via ``create_vectorstore``; otherwise the saved index
    is loaded with the module-level ``embeddings`` model.
    """
    index_path = "vectorstore/db"

    if not os.path.exists(index_path):
        print("Creating new vectorstore...")
        return create_vectorstore(load_pdfs(DOCUMENTS_DIR))

    print("Loading existing vectorstore...")
    # NOTE(review): allow_dangerous_deserialization=True opts in to loading
    # serialized Python objects from disk. That is only acceptable because the
    # index is expected to be one this program wrote itself; never point this
    # at an index obtained from an untrusted source.
    return FAISS.load_local(
        index_path,
        embeddings,
        allow_dangerous_deserialization=True,
    )

def get_relevant_context(query: str, vectorstore: FAISS) -> str:
    """Return the text of the three chunks most similar to ``query``.

    The chunks are taken from ``vectorstore.similarity_search`` in the order
    it returns them and are joined with a blank line between each.
    """
    matches = vectorstore.similarity_search(query, k=3)
    passages = (match.page_content for match in matches)
    return "\n\n".join(passages)

def generate_response(query: str, context: str) -> str:
    """Ask the Groq chat model to answer ``query`` using ``context``.

    The context and question are embedded in a single user message that
    instructs the model to say so when the context does not contain the
    answer. The content of the first returned choice is handed back.
    """
    # Assembled piecewise so every line break (and the trailing space after
    # the first sentence) is explicit in the source.
    prompt = (
        "You are a helpful AI assistant. Use the following context to answer the question. \n"
        "If you cannot answer the question based on the context, say so.\n"
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        f"Question: {query}\n"
        "\n"
        "Answer:"
    )
    conversation = [{"role": "user", "content": prompt}]

    completion = client.chat.completions.create(
        model="deepseek-r1-distill-llama-70b",
        messages=conversation,
        temperature=0.7,
        top_p=0.9,
        max_tokens=1000,
    )

    first_choice = completion.choices[0]
    return first_choice.message.content

def main():
    """Run the interactive question-answering loop on the console.

    The vectorstore is loaded (or built) once, then each non-empty line the
    user types is answered by retrieving context and querying the model.
    The loop ends when the user types ``quit`` (case-insensitive), presses
    Ctrl-C at the prompt, or when standard input is closed (Ctrl-D / EOF).
    Errors raised while answering a single question are printed and the
    loop continues.
    """
    # Initialize vectorstore
    vectorstore = load_or_create_vectorstore()

    print("\nRAG Chatbot initialized! Type 'quit' to exit.")

    while True:
        try:
            query = input("\nYour question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises on closed stdin or Ctrl-C; treat both as a
            # request to leave instead of crashing with a traceback.
            print()
            break

        if query.lower() == 'quit':
            break

        if not query:
            continue

        try:
            # Get relevant context
            context = get_relevant_context(query, vectorstore)

            # Generate response
            response = generate_response(query, context)

            print("\nAnswer:", response)

        except Exception as e:
            # Top-level boundary: report the failure and keep the session alive.
            print(f"\n❌ Error: {str(e)}")


if __name__ == "__main__":
    main()