import os
import glob
import uuid
import json
from datetime import datetime
from typing import List, Optional
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse, RedirectResponse
from pydantic import BaseModel
from jose import JWTError, jwt
from cachetools import LRUCache
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from database import (
    insert_conversation,
    get_session_list,
    get_session_history,
    get_settings,
    update_settings
)
from admin import admin_router, SECRET_KEY, ALGORITHM, get_admin_by_email
from models import ChatbotSettings, ModelProvider
import asyncio
import concurrent.futures
from groq import Groq
from dotenv import load_dotenv
import re

# Load environment variables
# Must run before any os.getenv() call below so values from a local .env
# file (if present) are visible.
load_dotenv()

# Initialize Groq client
# os.getenv returns None when GROQ_API_KEY is not set; no check is done here.
client = Groq(
    api_key=os.getenv("GROQ_API_KEY")
)

# Initialize app
app = FastAPI(title="DLC Chatbot API", version="1.0.0")
# Routes declared in the local `admin` module are attached to this app.
app.include_router(admin_router)

templates = Jinja2Templates(directory="templates")
# Create the directory up front; presumably so that mounting it below does
# not fail on a fresh checkout -- TODO confirm.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# NOTE(review): every origin, method and header is allowed while credentials
# are allowed too. Confirm this is intended before exposing the API publicly,
# since get_current_admin accepts a cookie-based admin token.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Embedding model
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Vectorstore from all JSON files in /documents
#
# A Q&A file is either a list of intent entries, or a dict whose list values
# hold intent entries. An intent entry is a dict carrying 'intent',
# 'questions' and either 'answers' or a single 'answer'; only the first
# answer is kept. Every question becomes one record in qa_data.
# The "*.json" pattern also matches the "*_chunks.json" files loaded further
# down; their entries are skipped here unless they carry 'intent'/'questions'.
qa_data = []
documents_dir = "documents"
json_files = glob.glob(os.path.join(documents_dir, "*.json"))
for json_path in json_files:
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # Reduce both supported layouts to a list of entry lists so that one
        # loop handles them.
        if isinstance(data, dict):
            entry_groups = [group for group in data.values() if isinstance(group, list)]
        elif isinstance(data, list):
            entry_groups = [data]
        else:
            entry_groups = []
        for group in entry_groups:
            for item in group:
                if not (isinstance(item, dict) and 'intent' in item and 'questions' in item):
                    continue
                answers = item.get('answers', [item.get('answer', '')])
                if not isinstance(answers, list):
                    answers = [answers]
                questions = item['questions']
                # A bare string would otherwise be iterated character by
                # character, producing one bogus record per character.
                if isinstance(questions, str):
                    questions = [questions]
                for q in questions:
                    qa_data.append({
                        'intent': item['intent'],
                        'question': q,
                        'answer': answers[0] if answers else ''
                    })
    except Exception as e:
        # Best effort: a broken file is reported and skipped; records that
        # were appended before the failure are kept.
        print(f"Error processing {json_path}: {str(e)}")

# Load PDF paragraph chunks
# Each chunk file is expected to be a list of objects with a "chunk" string;
# chunks of 50 characters or fewer are dropped.
chunk_texts = []
chunk_files = glob.glob(os.path.join(documents_dir, "*_chunks.json"))
for chunk_file in chunk_files:
    try:
        with open(chunk_file, "r", encoding="utf-8") as f:
            chunks = json.load(f)
        chunk_texts.extend([chunk["chunk"] for chunk in chunks if len(chunk["chunk"]) > 50])
    except Exception as e:
        print(f"Error processing {chunk_file}: {str(e)}")

# Combine for embedding
# Only the question text is embedded for Q&A records; the answer is looked up
# in qa_data after retrieval.
# NOTE(review): FAISS.from_texts appears to fail on an empty corpus, which
# would abort the import when /documents yields nothing -- confirm.
qa_texts = [qa['question'] for qa in qa_data]
all_texts = qa_texts + chunk_texts
vectorstore = FAISS.from_texts(all_texts, embedding_model)

# Cache
# Settings are read once at import time; ChatbotSettings defaults are used
# when get_settings() returns nothing.
settings = get_settings() or ChatbotSettings().dict()
# The LRU cache capacity is fixed here from the startup settings
# (1000 entries when "cache_max_size" is absent).
CACHE_MAX_SIZE = settings.get("cache_max_size", 1000)
RESPONSE_CACHE = LRUCache(maxsize=CACHE_MAX_SIZE)

# Thread executor
# Shared pool used by generate_response to run the blocking Groq call off the
# event loop.
executor = concurrent.futures.ThreadPoolExecutor()

# Models
# Request body accepted by POST /api/chat.
class QuestionRequest(BaseModel):
    # The user's message.
    question: str
    # Existing session to continue; when omitted the chat handler generates
    # a fresh UUID.
    session_id: Optional[str] = None

# Response body returned by POST /api/chat.
class ChatResponse(BaseModel):
    # Answer text (either freshly generated or served from the response cache).
    answer: str
    # Session the exchange belongs to; echoes the request's session_id or the
    # newly generated one.
    session_id: str

# One entry of the session listing (element type of SessionListResponse).
class SessionResponse(BaseModel):
    session_id: str
    created_at: datetime

# Response body returned by GET /api/sessions.
class SessionListResponse(BaseModel):
    sessions: List[SessionResponse]

# Request body accepted by POST /ask.
class Question(BaseModel):
    query: str

# Response body returned by POST /ask.
class Answer(BaseModel):
    answer: str
    # The /ask handler in this file only sets `answer`; the two fields below
    # keep their None defaults there.
    intent: Optional[str] = None
    confidence: Optional[float] = None

def _iter_qa_pairs(data):
    """Yield one ``{'intent', 'question', 'answer'}`` record per question in *data*.

    *data* is one parsed Q&A JSON payload: either a list of intent entries or
    a dict whose list values hold intent entries. Entries that are not dicts
    or lack 'intent'/'questions' are skipped. Only the first answer is kept.
    """
    if isinstance(data, dict):
        entry_groups = [group for group in data.values() if isinstance(group, list)]
    elif isinstance(data, list):
        entry_groups = [data]
    else:
        entry_groups = []

    for group in entry_groups:
        for item in group:
            if not (isinstance(item, dict) and 'intent' in item and 'questions' in item):
                continue
            answers = item.get('answers', [item.get('answer', '')])
            if not isinstance(answers, list):
                answers = [answers]
            questions = item['questions']
            # A bare string would otherwise be iterated character by character.
            if isinstance(questions, str):
                questions = [questions]
            for q in questions:
                yield {
                    'intent': item['intent'],
                    'question': q,
                    'answer': answers[0] if answers else ''
                }


# Function to rebuild vector store
def rebuild_vectorstore():
    """Reload every document from ``documents/`` and replace the search index.

    Reads all ``*.json`` files as Q&A sources and all ``*_chunks.json`` files
    as PDF paragraph chunks (chunks of 50 characters or fewer are dropped),
    embeds the questions and chunks, and rebinds the module-level ``qa_data``
    and ``vectorstore``.

    The new data is assembled in locals and the globals are swapped only after
    the index was built successfully, so a failure leaves the previous
    ``qa_data``/``vectorstore`` pair intact and consistent.

    Raises:
        Exception: whatever ``FAISS.from_texts`` raises; per-file read/parse
            errors are printed and the file is skipped.
    """
    global qa_data, vectorstore
    documents_dir = "documents"

    new_qa_data = []
    json_files = glob.glob(os.path.join(documents_dir, "*.json"))
    for json_path in json_files:
        try:
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Append one by one so records produced before a mid-file failure
            # are kept, as in the startup loader.
            for pair in _iter_qa_pairs(data):
                new_qa_data.append(pair)
        except Exception as e:
            print(f"Error processing {json_path}: {str(e)}")

    # Load PDF paragraph chunks
    chunk_texts = []
    chunk_files = glob.glob(os.path.join(documents_dir, "*_chunks.json"))
    for chunk_file in chunk_files:
        try:
            with open(chunk_file, "r", encoding="utf-8") as f:
                chunks = json.load(f)
            chunk_texts.extend([chunk["chunk"] for chunk in chunks if len(chunk["chunk"]) > 50])
        except Exception as e:
            print(f"Error processing {chunk_file}: {str(e)}")

    # Only question text is embedded for Q&A records; answers are looked up
    # in qa_data after retrieval.
    all_texts = [qa['question'] for qa in new_qa_data] + chunk_texts
    new_vectorstore = FAISS.from_texts(all_texts, embedding_model)

    # Publish both structures together now that the index exists.
    qa_data = new_qa_data
    vectorstore = new_vectorstore
    print(f"Vector store rebuilt with {len(qa_data)} Q&A pairs and {len(chunk_texts)} PDF paragraph chunks from {len(json_files)} Q&A files and {len(chunk_files)} chunk files")

# Check for rebuild flag
def check_rebuild_flag():
    """Rebuild the vector store if the rebuild flag file is present.

    The flag file ``vectorstore_rebuild_flag.txt`` (relative to the working
    directory) is deleted first and the rebuild runs afterwards, so a failing
    rebuild is not retried on every call.

    Returns:
        True if the flag was consumed and the rebuild succeeded, False if
        there was no flag or the removal/rebuild failed (failures are printed).
    """
    rebuild_flag_file = "vectorstore_rebuild_flag.txt"

    # Remove-and-see instead of exists()-then-remove(): the flag may be
    # consumed by someone else between the two calls.
    try:
        os.remove(rebuild_flag_file)
    except FileNotFoundError:
        return False
    except OSError as e:
        print(f"Error rebuilding vector store: {str(e)}")
        return False

    try:
        rebuild_vectorstore()
    except Exception as e:
        print(f"Error rebuilding vector store: {str(e)}")
        return False
    return True

# Get relevant context from vectorstore
async def get_relevant_context(query: str, k: int = 3) -> str:
    """Return the context text for the single best vector-store hit.

    A pending rebuild flag is processed first. If the best hit is the text of
    a known Q&A question, that question's stored answer is returned when it is
    non-empty; otherwise the hit's own text (e.g. a PDF chunk) is returned.

    Args:
        query: The user's question.
        k: Accepted but not used; the search always asks for one result.

    Returns:
        The context string, or "" when the search yields nothing.
    """
    check_rebuild_flag()

    hits = vectorstore.similarity_search(query, k=1)  # top-1 only
    if not hits:
        return ""

    best_text = hits[0].page_content

    # First Q&A record whose question equals the hit decides the answer.
    stored_answer = None
    for record in qa_data:
        if record['question'] == best_text:
            stored_answer = record['answer']
            break

    return stored_answer if stored_answer else best_text

# <think>...</think> reasoning blocks emitted by the model, across lines.
_THINK_BLOCK_RE = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)
# One or more leading prompt labels, each optionally preceded by whitespace.
_LEADING_LABEL_RE = re.compile(r'^(?:\s*(?:Context:|Question:|Answer:))+', re.IGNORECASE)


def _clean_llm_answer(raw_answer):
    """Strip reasoning blocks and echoed prompt labels from a model reply.

    Reasoning blocks are removed first so that a label which follows them
    (``<think>..</think>\\nAnswer: ..``) is at the start of the text when the
    label pattern runs. ``None`` or empty input yields "".
    """
    if not raw_answer:
        return ""
    without_reasoning = _THINK_BLOCK_RE.sub('', raw_answer)
    return _LEADING_LABEL_RE.sub('', without_reasoning).strip()


# Generate response using Groq LLM
async def generate_response(query: str, context: str) -> str:
    """Ask the Groq chat model to answer *query* using *context*.

    The blocking Groq client call runs in the shared thread pool so the event
    loop stays free while waiting for the model.

    Args:
        query: The user's question.
        context: Retrieved text inserted into the prompt.

    Returns:
        The model's reply with ``<think>`` blocks and leading
        'Context:'/'Question:'/'Answer:' labels removed.

    Raises:
        Exception: any error raised by the Groq client is propagated.
    """
    prompt = f"""You are a helpful AI assistant. Use the following context to answer the question. \nIf you cannot answer the question based on the context, say so.\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"""

    def call_groq():
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            model="deepseek-r1-distill-llama-70b",
            temperature=0.7,
            max_tokens=1000,
            top_p=0.9,
        )
        return chat_completion.choices[0].message.content

    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine.
    loop = asyncio.get_running_loop()
    raw_answer = await loop.run_in_executor(executor, call_groq)
    return _clean_llm_answer(raw_answer)

@app.post("/api/chat", response_model=ChatResponse)
async def chat(request: QuestionRequest):
    """Answer a chat message, using the response cache when possible.

    Flow: reuse or create the session id, serve a cached answer for the
    normalised (stripped, lower-cased) question if caching is enabled,
    otherwise retrieve context and ask the model. Answers to questions shorter
    than 200 characters are cached. The exchange is stored via
    ``insert_conversation`` when the ``log_conversations`` setting is on; this
    applies to cached answers as well.

    Raises:
        HTTPException: 400 for a blank question, 500 for any other failure.
    """
    # Validate outside the try block so the 400 is not converted into a 500.
    if not request.question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    try:
        session_id = request.session_id or str(uuid.uuid4())
        cache_key = request.question.strip().lower()
        cache_enabled = settings.get('enable_response_cache', True)
        log_enabled = settings.get('log_conversations', True)

        if cache_enabled and cache_key in RESPONSE_CACHE:
            cached_answer = RESPONSE_CACHE[cache_key]
            # Honour the logging switch on the cache path too.
            if log_enabled:
                insert_conversation(session_id, request.question, cached_answer, source="cache")
            return ChatResponse(answer=cached_answer, session_id=session_id)

        context = await get_relevant_context(request.question)
        answer = await generate_response(request.question, context)

        if cache_enabled and len(request.question) < 200:
            RESPONSE_CACHE[cache_key] = answer

        if log_enabled:
            insert_conversation(session_id, request.question, answer, source="groq")

        return ChatResponse(answer=answer, session_id=session_id)

    except HTTPException:
        # Keep deliberate HTTP errors as they are.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/sessions", response_model=SessionListResponse)
async def list_sessions():
    """Return all sessions reported by ``get_session_list``.

    Raises:
        HTTPException: 500 if the lookup fails.
    """
    try:
        payload = {"sessions": get_session_list()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving sessions: {str(e)}")
    return payload

@app.get("/api/sessions/{session_id}/history")
async def get_chat_history(session_id: str):
    """Return the stored exchanges of one session.

    Each history item's ``timestamp`` is converted to an ISO-8601 string in
    place when it is a date/time object; items whose timestamp is missing,
    ``None`` or already a string are passed through unchanged instead of
    failing the whole request.

    Raises:
        HTTPException: 500 if the history cannot be retrieved.
    """
    try:
        history = get_session_history(session_id)
        for item in history:
            timestamp = item.get("timestamp")
            if hasattr(timestamp, "isoformat"):
                item["timestamp"] = timestamp.isoformat()
        return {"session_id": session_id, "history": history}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving session history: {str(e)}")

@app.get("/", response_class=HTMLResponse)
async def chat_ui(request: Request):
    """Render the chat page from the ``index.html`` template."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)

@app.get("/admin/login", response_class=HTMLResponse)
async def admin_login(request: Request):
    """Render the admin login page from the ``admin_login.html`` template."""
    page_context = {"request": request}
    return templates.TemplateResponse("admin_login.html", page_context)

@app.get("/admin", response_class=HTMLResponse)
async def admin_dashboard(request: Request):
    """Render the admin dashboard for an authenticated admin.

    Unauthenticated requests are redirected (303) to ``/admin/login``. When no
    settings are stored yet, the dashboard is rendered with ``ChatbotSettings``
    defaults rather than bouncing an authenticated admin back to the login
    page.

    Raises:
        HTTPException: any non-404 error raised while loading the settings.
    """
    # Only an authentication failure should lead to the login page.
    try:
        admin = await get_current_admin(request)
    except HTTPException:
        return RedirectResponse(url="/admin/login", status_code=303)

    try:
        dashboard_settings = await get_chatbot_settings(admin)
    except HTTPException as exc:
        if exc.status_code != 404:
            raise
        # Nothing stored yet: fall back to the model defaults.
        dashboard_settings = ChatbotSettings()

    return templates.TemplateResponse(
        "admin_dashboard.html",
        {"request": request, "admin": admin, "settings": dashboard_settings},
    )

async def get_current_admin(request: Request):
    """Resolve the admin record for the JWT carried by *request*.

    The token is taken from the ``admin_token`` cookie, or else from the
    ``Authorization`` header. A leading ``Bearer`` scheme (any letter case) is
    removed from the header value; a header without that scheme is used as the
    raw token, as before.

    Returns:
        The admin object returned by ``get_admin_by_email`` for the token's
        ``sub`` claim.

    Raises:
        HTTPException: 401 when no token is supplied, the token cannot be
            decoded, it has no ``sub`` claim, or no admin matches.
    """
    token = request.cookies.get("admin_token")
    if not token:
        auth_header = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        # Strip only the leading scheme instead of replacing every
        # "Bearer " occurrence inside the value.
        if scheme.lower() == "bearer":
            token = credentials.strip()
        else:
            token = auth_header
    if not token:
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(status_code=401, detail="Invalid token")

    email = payload.get("sub")
    if not email:
        raise HTTPException(status_code=401, detail="Invalid token")

    admin = get_admin_by_email(email)
    if not admin:
        raise HTTPException(status_code=401, detail="Admin not found")
    return admin

@app.get("/api/settings", response_model=ChatbotSettings)
async def get_chatbot_settings(admin: dict = Depends(get_current_admin)):
    """Return the stored chatbot settings (admin only).

    Raises:
        HTTPException: 404 when ``get_settings`` returns nothing.
    """
    stored = get_settings()
    if stored:
        return ChatbotSettings(**stored)
    raise HTTPException(status_code=404, detail="Settings not found")

@app.put("/api/settings")
async def update_chatbot_settings(settings: ChatbotSettings, admin: dict = Depends(get_current_admin)):
    """Persist new chatbot settings (admin only) and apply them to this process.

    After a successful save the module-level settings dict, which the chat
    endpoint reads its cache and logging switches from, is updated in place so
    the change takes effect without a restart. Only the current process is
    refreshed; the response cache capacity stays as chosen at startup.

    Raises:
        HTTPException: 500 when ``update_settings`` reports failure.
    """
    if not update_settings(settings.dict()):
        raise HTTPException(status_code=500, detail="Error updating settings")

    # The parameter shadows the module-level `settings` dict, so reach it
    # through globals().
    live_settings = globals().get("settings")
    if isinstance(live_settings, dict):
        live_settings.update(settings.dict())

    return {"message": "Settings updated successfully"}

@app.post("/ask", response_model=Answer)
async def ask_question(question: Question):
    """Answer a one-off question without session handling, caching or logging.

    Returns:
        An ``Answer`` with only ``answer`` populated.

    Raises:
        HTTPException: 400 when the query is empty or whitespace only.
    """
    # Whitespace-only input carries no question either.
    if not question.query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    context = await get_relevant_context(question.query)
    answer = await generate_response(question.query, context)
    return Answer(answer=answer)

@app.get("/intents")
async def list_intents():
    """Return the distinct intents in the loaded Q&A data and their count.

    The order of the returned list is not defined.
    """
    distinct = {record['intent'] for record in qa_data}
    intents = list(distinct)
    return {"total_intents": len(intents), "intents": intents}

@app.get("/health")
async def health_check():
    """Report a static "healthy" status with Q&A corpus counts."""
    pair_count = len(qa_data)
    distinct_intents = {record['intent'] for record in qa_data}
    return {
        "status": "healthy",
        "qa_pairs_loaded": pair_count,
        "intents_available": len(distinct_intents),
    }

if __name__ == "__main__":
    import uvicorn

    # uvicorn refuses to start multiple workers when handed an app object;
    # it needs an import string so that each worker can import the app itself.
    # The module name is derived from this file so the script keeps working
    # if it is renamed.
    # NOTE(review): every worker process loads its own embedding model,
    # response cache and settings snapshot, and the rebuild flag file is
    # consumed by whichever worker sees it first -- confirm that is acceptable
    # or run a single worker.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        f"{module_name}:app",
        host="0.0.0.0",
        port=8000,
        loop="uvloop",
        http="httptools",
        workers=4,
    )