davidguzmanr's picture
Update app
6319a21 verified
import streamlit as st
import json
import boto3
from typing import Dict
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
st.set_page_config(layout="centered")
st.title("Arkham challenge")
st.markdown("""
## Propuesta de solución
Para resolver el desafío usé RAG (Retrieval Augmented Generation). La idea general de RAG
es que el modelo recupera documentos contextuales de un conjunto de datos externo como parte
de su ejecución, estos documentos contextuales se utilizan junto con la entrada original
para producir la salida final.
![RAG](https://huggingface.co/blog/assets/12_ray_rag/rag_gif.gif "RAG")
Para este caso los documentos serán particiones del contrato (aunque fácilmente podemos agregar
más documentos) que extraigo con OCR. Estas particiones nos ayudarán a agregar información contextual
para que una LLM pueda contestar las preguntas que le hagamos.
""")
with open("credentials.json") as file:
credentials = json.load(file)
sagemaker_client = boto3.client(
service_name="sagemaker-runtime",
region_name="us-east-1",
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
translate_client = boto3.client(
service_name="translate",
region_name="us-east-1",
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
st.markdown("## QA sobre el contrato")
pregunta = st.text_input(
label="Escribe tu pregunta",
value="¿Quién es el depositario?",
help="""
Escribe tu pregunta, por ejemplo:
- ¿Cuáles son las obligaciones del arrendatario?
- ¿Qué es FIRA?
""",
)
embeddings = HuggingFaceEmbeddings(
model_name="intfloat/multilingual-e5-small",
)
embeddings_db = FAISS.load_local("faiss_index", embeddings)
retriever = embeddings_db.as_retriever(search_kwargs={"k": 5})
prompt_template = """
Please answer the question below, using only the context below.
Don't invent facts, if you can't provide a factual answer, say you don't know what the answer is.
question: {question}
context: {context}
"""
prompt = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
# Endpoint de SageMaker
model_kwargs = {
"max_new_tokens": 512,
"top_p": 0.8,
"temperature": 0.8,
"repetition_penalty": 1.0,
}
class ContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps(
# Template de prompt para Mistral, ver https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
{"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": {**model_kwargs}}
)
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
splits = response_json[0]["generated_text"].split("[/INST] ")
return splits[1]
content_handler = ContentHandler()
llm = SagemakerEndpoint(
endpoint_name="mistral-langchain",
model_kwargs=model_kwargs,
content_handler=content_handler,
client=sagemaker_client,
)
chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
chain_type_kwargs={"prompt": prompt},
)
question = translate_client.translate_text(
Text=pregunta,
SourceLanguageCode="es",
TargetLanguageCode="en",
Settings={
"Formality": "FORMAL",
},
).get("TranslatedText")
answer = chain.run({"query": question})
respuesta = translate_client.translate_text(
Text=answer,
SourceLanguageCode="en",
TargetLanguageCode="es",
Settings={
"Formality": "FORMAL",
},
).get("TranslatedText")
st.write(respuesta)