import streamlit as st import json import boto3 from typing import Dict from langchain.vectorstores import FAISS from langchain.embeddings import HuggingFaceEmbeddings from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate from langchain.llms import SagemakerEndpoint from langchain.llms.sagemaker_endpoint import LLMContentHandler st.set_page_config(layout="centered") st.title("Arkham challenge") st.markdown(""" ## Propuesta de solución Para resolver el desafío usé RAG (Retrieval Augmented Generation). La idea general de RAG es que el modelo recupera documentos contextuales de un conjunto de datos externo como parte de su ejecución, estos documentos contextuales se utilizan junto con la entrada original para producir la salida final. ![RAG](https://huggingface.co/blog/assets/12_ray_rag/rag_gif.gif "RAG") Para este caso los documentos serán particiones del contrato (aunque fácilmente podemos agregar más documentos) que extraigo con OCR. Estas particiones nos ayudarán a agregar información contextual para que una LLM pueda contestar las preguntas que le hagamos. """) with open("credentials.json") as file: credentials = json.load(file) sagemaker_client = boto3.client( service_name="sagemaker-runtime", region_name="us-east-1", aws_access_key_id=credentials["aws_access_key_id"], aws_secret_access_key=credentials["aws_secret_access_key"], ) translate_client = boto3.client( service_name="translate", region_name="us-east-1", aws_access_key_id=credentials["aws_access_key_id"], aws_secret_access_key=credentials["aws_secret_access_key"], ) st.markdown("## QA sobre el contrato") pregunta = st.text_input( label="Escribe tu pregunta", value="¿Quién es el depositario?", help=""" Escribe tu pregunta, por ejemplo: - ¿Cuáles son las obligaciones del arrendatario? - ¿Qué es FIRA? """, ) embeddings = HuggingFaceEmbeddings( model_name="intfloat/multilingual-e5-small", ) embeddings_db = FAISS.load_local("faiss_index", embeddings) retriever = embeddings_db.as_retriever(search_kwargs={"k": 5}) prompt_template = """ Please answer the question below, using only the context below. Don't invent facts, if you can't provide a factual answer, say you don't know what the answer is. question: {question} context: {context} """ prompt = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) # Endpoint de SageMaker model_kwargs = { "max_new_tokens": 512, "top_p": 0.8, "temperature": 0.8, "repetition_penalty": 1.0, } class ContentHandler(LLMContentHandler): content_type = "application/json" accepts = "application/json" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes: input_str = json.dumps( # Template de prompt para Mistral, ver https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 {"inputs": f"[INST] {prompt} [/INST]", "parameters": {**model_kwargs}} ) return input_str.encode("utf-8") def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) splits = response_json[0]["generated_text"].split("[/INST] ") return splits[1] content_handler = ContentHandler() llm = SagemakerEndpoint( endpoint_name="mistral-langchain", model_kwargs=model_kwargs, content_handler=content_handler, client=sagemaker_client, ) chain = RetrievalQA.from_chain_type( llm=llm, chain_type="stuff", retriever=retriever, chain_type_kwargs={"prompt": prompt}, ) question = translate_client.translate_text( Text=pregunta, SourceLanguageCode="es", TargetLanguageCode="en", Settings={ "Formality": "FORMAL", }, ).get("TranslatedText") answer = chain.run({"query": question}) respuesta = translate_client.translate_text( Text=answer, SourceLanguageCode="en", TargetLanguageCode="es", Settings={ "Formality": "FORMAL", }, ).get("TranslatedText") st.write(respuesta)