import os
import logging

from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    get_response_synthesizer)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from qdrant_client import QdrantClient

# Qdrant connection settings, read once at import time from the environment.
# Each is None when the corresponding variable is unset; ChatPDF passes both
# straight to QdrantClient.
QDRANT_API_URL = os.environ.get('QDRANT_API_URL')
QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY')


class ChatPDF:
    """Question answering over a directory of documents.

    Documents are split into sentence-based chunks, embedded with an Ollama
    embedding model, stored in the Qdrant collection "rag_documents", and
    queried through a retriever query engine wrapped in a HyDE query
    transform.
    """

    # Class-level fallbacks so that ``ask`` reports "no documents" even on an
    # instance whose engine was never built. Mutable per-instance state (the
    # lists) is created in ``__init__`` instead, so instances never share
    # the same list objects.
    hyde_query_engine = None
    logger = None

    def __init__(self):
        """Configure logging, the Qdrant vector store and the Ollama models.

        Side effect: overwrites llama_index's global ``Settings.embed_model``,
        ``Settings.llm`` and ``Settings.transformations``.
        """
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Results of the most recent ``ingest`` call, owned by this instance.
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None

        # Kept on the instance (not as locals) because ``ingest`` uses them.
        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)

        self.logger.info("initializing the vector store related objects")
        client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
        self.vector_store = QdrantVectorStore(
            client=client, collection_name="rag_documents"
        )

        self.logger.info("initializing the OllamaEmbedding")
        self.embed_model = OllamaEmbedding(model_name='mxbai-embed-large')
        self.logger.info("initializing the global settings")
        Settings.embed_model = self.embed_model
        Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
        Settings.transformations = [self.text_parser]

    def ingest(self, dir_path: str):
        """Load, chunk, embed and index every document under *dir_path*.

        On success, replaces ``text_chunks``, ``doc_ids`` and ``nodes`` with
        the results of this call and rebuilds ``hyde_query_engine``. Fresh
        lists are built on every call, so chunks from an earlier call are
        neither mis-attributed nor re-inserted into the vector store.

        Args:
            dir_path: directory handed to ``SimpleDirectoryReader``.
        """
        docs = SimpleDirectoryReader(input_dir=dir_path).load_data()
        text_chunks, doc_ids, nodes = self._build_nodes(docs)
        self._embed_nodes(nodes)
        query_engine = self._build_query_engine(nodes)

        self.text_chunks = text_chunks
        self.doc_ids = doc_ids
        self.nodes = nodes
        self.hyde_query_engine = query_engine

    def _build_nodes(self, docs):
        """Split *docs* into text chunks and wrap each chunk in a TextNode.

        Returns:
            ``(text_chunks, doc_ids, nodes)``, where ``doc_ids[i]`` is the
            index in *docs* of the document that ``text_chunks[i]`` came from.
        """
        self.logger.info("enumerating docs")
        text_chunks = []
        doc_ids = []
        for doc_idx, doc in enumerate(docs):
            curr_text_chunks = self.text_parser.split_text(doc.text)
            text_chunks.extend(curr_text_chunks)
            doc_ids.extend([doc_idx] * len(curr_text_chunks))

        self.logger.info("enumerating text_chunks")
        nodes = []
        for text_chunk, doc_idx in zip(text_chunks, doc_ids):
            node = TextNode(text=text_chunk)
            # Copy so that chunks of one document do not share a single dict.
            node.metadata = dict(docs[doc_idx].metadata)
            nodes.append(node)
        return text_chunks, doc_ids, nodes

    def _embed_nodes(self, nodes):
        """Set ``node.embedding`` on each node from its text plus metadata."""
        self.logger.info("enumerating nodes")
        for node in nodes:
            node.embedding = self.embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )

    def _build_query_engine(self, nodes):
        """Index *nodes* in the Qdrant store and return a HyDE query engine."""
        self.logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )
        self.logger.info("indexing the nodes in VectorStoreIndex")
        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            transformations=Settings.transformations,
        )

        self.logger.info("initializing the VectorIndexRetriever with top_k as 5")
        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=5)
        response_synthesizer = get_response_synthesizer()
        self.logger.info("creating the RetrieverQueryEngine instance")
        vector_query_engine = RetrieverQueryEngine(
            retriever=vector_retriever,
            response_synthesizer=response_synthesizer,
        )
        self.logger.info("creating the HyDEQueryTransform instance")
        # include_original=True: per llama_index docs, the original query
        # string is kept alongside the generated hypothetical document.
        hyde = HyDEQueryTransform(include_original=True)
        return TransformQueryEngine(vector_query_engine, hyde)

    def ask(self, query: str):
        """Answer *query* with the engine built by the last ``ingest``.

        Returns:
            The query engine's response object, or the string
            "Please, add a PDF document first." when nothing has been
            ingested yet (or after ``clear``).
        """
        if self.hyde_query_engine is None:
            return "Please, add a PDF document first."

        self.logger.info("retrieving the response to the query")
        response = self.hyde_query_engine.query(str_or_query_bundle=query)
        self.logger.info(response)
        return response

    def clear(self):
        """Forget the in-memory ingest results and the query engine.

        Only instance attributes are reset; nothing is removed from the
        Qdrant collection.
        """
        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []
        self.hyde_query_engine = None