import os

from openai import OpenAI
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS, Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains.question_answering import load_qa_chain
import streamlit as st
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv()) # read local .env file
 
#Variables globales
embeddings = None
modelo_seleccionado = None
simil_seleccionado = None
tipologia_seleccionado = None
OpenAI.api_key  = os.environ['OPENAI_API_KEY']
 
def creacion_front():
      global embeddings
      global simil_seleccionado
      global tipologia_seleccionado
      st.set_page_config('Prueba generica')
      st.header("Tester IA")      
      # Crear la segunda fila con botones de radio
      first_row_data = []
      # Para la columna "Modelo"
      modelo_options = ["paraphrase-multilingual", "OpenAI", "multilingual-mpnet-base-v2"]
      selected_modelo = st.radio("Selecciona modelo de vectorización", modelo_options)
      first_row_data.append(selected_modelo)
      # Para la columna "tipologia"
      tipologia_options = ["Similarity search", "Maximun marginal revelance (MMR)"]
      selected_tipologia = st.radio("Selecciona tipologia de busqueda de similitudes", tipologia_options)
      first_row_data.append(selected_tipologia)
      # Para la columna "LLM"
      simil_options = ["FAISS", "Chroma"]
      selected_simil = st.radio("Selecciona modelo de busqueda de similitudes:", simil_options)
      first_row_data.append(selected_simil)      
     
      # Obtener los valores seleccionados
      modelo_seleccionado = first_row_data[0]
      simil_seleccionado = first_row_data[1]
      tipologia_seleccionado = first_row_data[2]
           
      if modelo_seleccionado == "paraphrase-multilingual":
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
      elif modelo_seleccionado == "multilingual-mpnet-base-v2":
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
      else:
            embeddings = OpenAIEmbeddings()                                  
     
def tratamiento_texto(pdf):
      pdf_reader = PdfReader(pdf)
      text = ""
      for page in pdf_reader.pages:
            text += page.extract_text()
      return text
 
def proceso_chunk(text):
      text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size_param,
            chunk_overlap=chunk_overlap_param,
            length_function=len
      )
      #Dvision en chunks
      return text_splitter.split_text(text)
 
def proceso_vectorizacion(parrafos_chunks):
      global knowledge_base
      #Obtencion embeddings
      if tipologia_seleccionado == "FAISS":
            knowledge_base = FAISS.from_texts(parrafos_chunks, embeddings)
      else:
            knowledge_base = Chroma.from_texts(parrafos_chunks, embeddings)
      return knowledge_base
     
def proceso_seleccion_chunks(knowledge_base, pregunta):
      if simil_seleccionado == "Similarity search":
            chunks_seleccionados = knowledge_base.similarity_search(pregunta, num_chunk_to_llm_param)
      elif simil_seleccionado == "Maximun marginal revelance (MMR)":
            chunks_seleccionados = knowledge_base.max_marginal_relevance_search(pregunta, num_chunk_to_llm_param)
      else:
            chunks_seleccionados = None
      for i, elemento in enumerate(chunks_seleccionados):
            texto=f"Chunk[{i}] -> {elemento}"
            #st.write(f"Chunk[{i}] -> {elemento}")        
            st.info(texto,icon="🚨")
      return chunks_seleccionados
 
def busqueda_con_llm(entrada,pregunta):
      #Definicion modelo
      llm = ChatOpenAI(model_name='gpt-3.5-turbo')
      #PRecarga QA
      chain = load_qa_chain(llm, chain_type="stuff")
      respuesta = chain.run(input_documents=entrada, question=pregunta)
      st.write(respuesta)  
     
 
@st.cache_resource
def obtencion_vectores():
      texto_tratado = tratamiento_texto(pdf_obj)
      parrafos_chunks = proceso_chunk(texto_tratado)
      knowledge_base = proceso_vectorizacion(parrafos_chunks)
      return knowledge_base
     
if __name__ == "__main__":
      creacion_front()
      chunk_size_param=st.slider('Chunk size param?', 0, 2000, 800)
      chunk_overlap_param=st.slider('Chunk overlap param?', 0, 500, 100)
      num_chunk_to_llm_param=st.slider('Number of chunks to the LLM?', 0, 20, 5)
      pdf_obj = st.file_uploader("Carga tu documento", type="pdf", on_change=st.cache_resource.clear)
      if pdf_obj:
            knowledge_base = obtencion_vectores()
            pregunta_chunks = st.text_input("Haz una pregunta para generar los chunks:")
            if pregunta_chunks:
                  chunks_para_llm = proceso_seleccion_chunks(knowledge_base, pregunta_chunks)
                  pregunta_llm = st.text_input("Haz una pregunta a ChatGPT:")
                  if pregunta_llm:
                        busqueda_con_llm(chunks_para_llm,pregunta_llm)