import streamlit as st
import json
import boto3
from typing import Dict
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
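# Note: these import paths match pre-0.1 LangChain; in langchain>=0.1 the same
# classes live under langchain_community (e.g. langchain_community.vectorstores).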
st.set_page_config(layout="centered")
st.title("Arkham challenge")
| st.markdown(""" | |
| ## Propuesta de solución | |
| Para resolver el desafío usé RAG (Retrieval Augmented Generation). La idea general de RAG | |
| es que el modelo recupera documentos contextuales de un conjunto de datos externo como parte | |
| de su ejecución, estos documentos contextuales se utilizan junto con la entrada original | |
| para producir la salida final. | |
|  | |
| Para este caso los documentos serán particiones del contrato (aunque fácilmente podemos agregar | |
| más documentos) que extraigo con OCR. Estas particiones nos ayudarán a agregar información contextual | |
| para que una LLM pueda contestar las preguntas que le hagamos. | |
| """) | |
with open("credentials.json") as file:
    credentials = json.load(file)

sagemaker_client = boto3.client(
    service_name="sagemaker-runtime",
    region_name="us-east-1",
    aws_access_key_id=credentials["aws_access_key_id"],
    aws_secret_access_key=credentials["aws_secret_access_key"],
)

translate_client = boto3.client(
    service_name="translate",
    region_name="us-east-1",
    aws_access_key_id=credentials["aws_access_key_id"],
    aws_secret_access_key=credentials["aws_secret_access_key"],
)
| st.markdown("## QA sobre el contrato") | |
| pregunta = st.text_input( | |
| label="Escribe tu pregunta", | |
| value="¿Quién es el depositario?", | |
| help=""" | |
| Escribe tu pregunta, por ejemplo: | |
| - ¿Cuáles son las obligaciones del arrendatario? | |
| - ¿Qué es FIRA? | |
| """, | |
| ) | |
embeddings = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-small",
)
embeddings_db = FAISS.load_local("faiss_index", embeddings)
# Retrieve the 5 contract chunks most similar to the question.
retriever = embeddings_db.as_retriever(search_kwargs={"k": 5})
| prompt_template = """ | |
| Please answer the question below, using only the context below. | |
| Don't invent facts, if you can't provide a factual answer, say you don't know what the answer is. | |
| question: {question} | |
| context: {context} | |
| """ | |
| prompt = PromptTemplate( | |
| template=prompt_template, input_variables=["context", "question"] | |
| ) | |
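# For illustration, the chain below fills the template roughly like this
# (hypothetical values, shown only to make the data flow explicit):
#
#   prompt.format(
#       question="Who is the depositary?",
#       context="<the 5 contract chunks returned by the retriever>",
#   )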
# SageMaker endpoint
model_kwargs = {
    "max_new_tokens": 512,
    "top_p": 0.8,
    "temperature": 0.8,
    "repetition_penalty": 1.0,
}
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Prompt template for Mistral, see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
        input_str = json.dumps(
            {"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": {**model_kwargs}}
        )
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        # The endpoint echoes the prompt back, so keep only the text generated
        # after the final [/INST] tag.
        splits = response_json[0]["generated_text"].split("[/INST] ")
        return splits[-1]
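
# Hedged example of the round trip the handler implements, assuming the standard
# TGI payload shapes (the question and answer strings are hypothetical):
#   request body:  {"inputs": "<s>[INST] Who is the depositary? [/INST]", "parameters": {...}}
#   response body: [{"generated_text": "<s>[INST] Who is the depositary? [/INST] The depositary is ..."}]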
content_handler = ContentHandler()

llm = SagemakerEndpoint(
    endpoint_name="mistral-langchain",
    model_kwargs=model_kwargs,
    content_handler=content_handler,
    client=sagemaker_client,
)
# chain_type="stuff" concatenates all retrieved chunks into a single prompt
# built from the template above.
chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
)
# Only query the endpoints when the input is non-empty (Amazon Translate
# rejects empty text).
if pregunta:
    # The question arrives in Spanish but the prompt template is in English,
    # so translate it before retrieval.
    question = translate_client.translate_text(
        Text=pregunta,
        SourceLanguageCode="es",
        TargetLanguageCode="en",
        Settings={
            "Formality": "FORMAL",
        },
    ).get("TranslatedText")

    answer = chain.run({"query": question})

    # Translate the English answer back to Spanish for display.
    respuesta = translate_client.translate_text(
        Text=answer,
        SourceLanguageCode="en",
        TargetLanguageCode="es",
        Settings={
            "Formality": "FORMAL",
        },
    ).get("TranslatedText")

    st.write(respuesta)
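
# To run locally (assuming this file is named app.py and that credentials.json,
# the faiss_index directory, and dependencies such as streamlit, boto3,
# langchain, sentence-transformers, and faiss-cpu are in place):
#   streamlit run app.py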