import logging
import uuid
from typing import List

import docx
import PyPDF2
from openai import OpenAI
from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, Column, ForeignKey, Text
from sqlalchemy.dialects.postgresql import UUID as pg_UUID
from sqlalchemy.orm import declarative_base, sessionmaker, Session

from config import settings

# Database engine and session factory, configured from application settings.
engine = create_engine(settings.DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Shared OpenAI client used for all embedding requests in this module.
client = OpenAI(api_key=settings.OPENAI_API_KEY)

class DocumentChunk(Base):
    """ORM model for one chunk of an uploaded document and its embedding.

    Stored in Postgres via pgvector; the vector is sized 1536 to match the
    OpenAI ``text-embedding-3-small`` model used elsewhere in this module.
    """
    __tablename__ = "document_chunks"
    # Surrogate primary key, generated client-side.
    id = Column(pg_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # NOTE(review): project_id/document_id carry no ForeignKey constraint —
    # presumably they reference project/document tables; confirm whether
    # FKs (and indexes, given the filter in search_knowledge_base) are intended.
    project_id = Column(pg_UUID(as_uuid=True)) 
    document_id = Column(pg_UUID(as_uuid=True))
    content = Column(Text, nullable=False)
    embedding = Column(Vector(1536)) 

# Create the table at import time if it does not already exist.
Base.metadata.create_all(bind=engine)

def extract_text(file_path: str, file_type: str) -> str:
    """Extract plain text from a PDF, DOCX, or TXT file.

    Args:
        file_path: Path to the file on disk.
        file_type: One of ``'pdf'``, ``'docx'``, or ``'txt'``.

    Returns:
        The extracted text. For PDF and DOCX, each page/paragraph is
        followed by a newline; TXT content is returned as-is.

    Raises:
        ValueError: If ``file_type`` is not a supported type.
    """
    if file_type == 'pdf':
        with open(file_path, 'rb') as f:
            reader = PyPDF2.PdfReader(f)
            # extract_text() can return None for pages with no text layer.
            pages = (page.extract_text() for page in reader.pages)
            return "".join(p + "\n" for p in pages if p)
    if file_type == 'docx':
        doc = docx.Document(file_path)
        return "".join(para.text + "\n" for para in doc.paragraphs)
    if file_type == 'txt':
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    # Previously an unsupported type silently yielded "", so the document
    # appeared processed but produced zero stored chunks. Fail loudly instead.
    raise ValueError(f"Unsupported file type: {file_type!r}")

def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Split *text* into overlapping fixed-size chunks.

    Args:
        text: The text to split; an empty string yields an empty list.
        chunk_size: Maximum length of each chunk (must be positive).
        overlap: Number of characters shared between consecutive chunks
            (must be smaller than ``chunk_size``).

    Returns:
        List of chunks; consecutive chunks overlap by ``overlap`` characters.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        # With overlap >= chunk_size the step (chunk_size - overlap) is <= 0
        # and the original loop never terminated.
        raise ValueError("overlap must be smaller than chunk_size")

    chunks: List[str] = []
    start = 0
    text_length = len(text)
    while start < text_length:
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= text_length:
            # Final chunk already reached the end; continuing would emit a
            # redundant tail chunk fully contained in this one.
            break
        start = end - overlap  # == start + (chunk_size - overlap)
    return chunks

def process_and_store_document(project_id: str, document_id: str, file_path: str, file_type: str):
    """Extract a document's text, chunk it, embed each chunk, and persist it.

    Args:
        project_id: Owning project's UUID (stored on each chunk row).
        document_id: Source document's UUID (stored on each chunk row).
        file_path: Path to the uploaded file on disk.
        file_type: One of the types supported by :func:`extract_text`.

    Raises:
        Exception: Any extraction, embedding, or database error is re-raised
            after rolling back the transaction, so callers can observe and
            retry the failure (previously errors were swallowed by a print).
    """
    text = extract_text(file_path, file_type)
    chunks = chunk_text(text)
    if not chunks:
        # Empty document: nothing to embed or store.
        return

    db: Session = SessionLocal()
    try:
        # Batch chunks into a few embeddings requests instead of one HTTP
        # round-trip per chunk; the API returns one vector per input, in order.
        batch_size = 256  # stays well under the API's per-request input limit
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            response = client.embeddings.create(
                input=batch,
                model="text-embedding-3-small",
            )
            for chunk, item in zip(batch, response.data):
                db.add(DocumentChunk(
                    project_id=project_id,
                    document_id=document_id,
                    content=chunk,
                    embedding=item.embedding,
                ))
        db.commit()
    except Exception:
        db.rollback()
        logging.getLogger(__name__).exception(
            "Error processing document %s", document_id
        )
        raise
    finally:
        db.close()

def search_knowledge_base(project_id: str, query: str, limit: int = 3) -> str:
    """Return the text of the chunks most similar to *query*, joined by blank lines.

    Embeds the query with ``text-embedding-3-small``, then fetches the
    ``limit`` nearest chunks (by pgvector cosine distance) belonging to the
    given project.
    """
    embedding_response = client.embeddings.create(
        input=query,
        model="text-embedding-3-small"
    )
    query_vector = embedding_response.data[0].embedding

    session: Session = SessionLocal()
    try:
        nearest_chunks = (
            session.query(DocumentChunk)
            .filter(DocumentChunk.project_id == project_id)
            .order_by(DocumentChunk.embedding.cosine_distance(query_vector))
            .limit(limit)
            .all()
        )
        return "\n\n".join(chunk.content for chunk in nearest_chunks)
    finally:
        session.close()