Spaces:
Sleeping
Sleeping
davidguzmanr
commited on
Commit
•
7b2ccf1
1
Parent(s):
b0194fa
Add app.py
Browse files- app.py +138 -0
- credentials.json +4 -0
- deployment/QA with OCR deployment.ipynb +521 -0
- deployment/credentials.json +4 -0
- faiss_index/index.faiss +0 -0
- faiss_index/index.pkl +3 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
import json
|
4 |
+
import boto3
|
5 |
+
from typing import Dict
|
6 |
+
|
7 |
+
from langchain.vectorstores import FAISS
|
8 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
9 |
+
from langchain.chains import RetrievalQA
|
10 |
+
from langchain.prompts import PromptTemplate
|
11 |
+
from langchain.llms import SagemakerEndpoint
|
12 |
+
from langchain.llms.sagemaker_endpoint import LLMContentHandler
|
13 |
+
|
14 |
+
# Page chrome: centered layout plus the challenge title.
st.set_page_config(layout="centered")
st.title("Arkham challenge")

# Spanish-language overview of the RAG approach, rendered as Markdown at the
# top of the page. The text is user-facing and is kept verbatim.
_INTRO_MARKDOWN = """
## Propuesta de solución
Para resolver el desafío usé RAG (Retrieval Augmented Generation). La idea general de RAG
es que el modelo recupera documentos contextuales de un conjunto de datos externo como parte
de su ejecución, estos documentos contextuales se utilizan junto con la entrada original
para producir la salida final.

![RAG](https://huggingface.co/blog/assets/12_ray_rag/rag_gif.gif "RAG")

Para este caso los documentos serán particiones del contrato (aunque fácilmente podemos agregar
más documentos) que extraigo con OCR. Estas particiones nos ayudarán a agregar información contextual
para que una LLM pueda contestar las preguntas que le hagamos.
"""
st.markdown(_INTRO_MARKDOWN)
|
30 |
+
|
31 |
+
def _load_aws_credentials(path="credentials.json"):
    """Return explicit AWS key kwargs for ``boto3.client``, or ``{}``.

    The key file is optional: when it does not exist an empty dict is
    returned, so the clients below are created without explicit keys and
    boto3 resolves credentials through its default chain (environment
    variables, shared config, instance role). This lets the key file stay
    out of version control.

    Args:
        path: JSON file holding ``aws_access_key_id`` and
            ``aws_secret_access_key``.

    Returns:
        A dict with exactly those two keys, or an empty dict if the file is
        missing.

    Raises:
        KeyError: the file exists but lacks one of the two keys.
        json.JSONDecodeError: the file exists but is not valid JSON.
    """
    try:
        with open(path, encoding="utf-8") as file:
            stored = json.load(file)
    except FileNotFoundError:
        return {}
    # Forward only the two keys boto3 expects; any other entry in the file
    # would be rejected as an unknown keyword argument.
    return {
        "aws_access_key_id": stored["aws_access_key_id"],
        "aws_secret_access_key": stored["aws_secret_access_key"],
    }


credentials = _load_aws_credentials()

# Both services are called in the same region.
AWS_REGION = "us-east-1"

# Runtime client used to invoke the deployed SageMaker endpoint.
sagemaker_client = boto3.client(
    service_name="sagemaker-runtime",
    region_name=AWS_REGION,
    **credentials,
)

# Amazon Translate client used for the es <-> en round trip.
translate_client = boto3.client(
    service_name="translate",
    region_name=AWS_REGION,
    **credentials,
)
|
47 |
+
|
48 |
+
st.markdown("## QA sobre el contrato")

# Tooltip with example questions, shown next to the input box (user-facing
# Spanish text, kept verbatim).
_QUESTION_HELP = """
Escribe tu pregunta, por ejemplo:
- ¿Cuáles son las obligaciones del arrendatario?
- ¿Qué es FIRA?
"""

# Free-text question from the user, pre-filled with a sample question.
pregunta = st.text_input(
    "Escribe tu pregunta",
    value="¿Quién es el depositario?",
    help=_QUESTION_HELP,
)
|
58 |
+
|
59 |
+
@st.cache_resource
def _load_vector_store():
    """Load the embedding model and the local FAISS index once per process.

    Streamlit re-executes this script on every widget interaction; without
    caching, the sentence-embedding model and the index would be reloaded
    from disk for each question typed.

    Returns:
        A ``(embeddings, embeddings_db)`` tuple: the HuggingFace embedding
        wrapper and the FAISS store loaded from the ``faiss_index`` folder.
    """
    model = HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-small",
    )
    return model, FAISS.load_local("faiss_index", model)


embeddings, embeddings_db = _load_vector_store()

# Retrieve the 2 most similar chunks for each query.
retriever = embeddings_db.as_retriever(search_kwargs={"k": 2})
|
64 |
+
|
65 |
+
# Instruction text sent to the LLM; {question} and {context} are filled in
# by the PromptTemplate below.
prompt_template = """
Please answer the question below, using only the context below.
Don't invent facts, if you can't provide a factual answer, say you don't know what the answer is.

question: {question}

context: {context}
"""
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

# Generation parameters forwarded to the SageMaker endpoint with every request.
model_kwargs = dict(
    max_new_tokens=512,
    top_p=0.8,
    temperature=0.8,
    repetition_penalty=1.0,
)
|
84 |
+
|
85 |
+
|
86 |
+
class ContentHandler(LLMContentHandler):
    """Serialize prompts for, and parse responses from, the Mistral endpoint.

    Requests and responses are JSON. The prompt is wrapped in Mistral's
    instruction template, see
    https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        """Build the UTF-8 encoded JSON request body.

        Args:
            prompt: Fully rendered prompt text.
            model_kwargs: Generation parameters, copied into ``parameters``.

        Returns:
            The JSON payload ``{"inputs": ..., "parameters": ...}`` as bytes.
        """
        payload = {
            "inputs": f"<s>[INST] {prompt} [/INST]",
            "parameters": {**model_kwargs},
        }
        return json.dumps(payload).encode("utf-8")

    def transform_output(self, output) -> str:
        """Extract the model's completion from the endpoint response.

        Args:
            output: Response body exposing ``read()`` that yields UTF-8 JSON
                of the form ``[{"generated_text": ...}]``.

        Returns:
            The text after the first ``[/INST]`` marker with one leading
            space removed, or the whole generated text when the marker is
            absent.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        generated_text = response_json[0]["generated_text"]
        # Partition on the first marker only: indexing the result of
        # split("[/INST] ") raised IndexError when the marker was missing or
        # the completion was empty, and cut the answer short if the answer
        # itself contained the marker.
        _, marker, completion = generated_text.partition("[/INST]")
        if not marker:
            # The endpoint did not echo the prompt; the text is the answer.
            return generated_text
        if completion.startswith(" "):
            completion = completion[1:]
        return completion
|
101 |
+
|
102 |
+
|
103 |
+
content_handler = ContentHandler()

# LLM wrapper around the deployed SageMaker endpoint.
llm = SagemakerEndpoint(
    endpoint_name="mistral-langchain",
    model_kwargs=model_kwargs,
    content_handler=content_handler,
    client=sagemaker_client,
)

# "stuff" chain: the retrieved chunks are inserted into the prompt's
# {context} slot and sent to the LLM in a single call.
chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
)


def _translate(text, source, target):
    """Translate ``text`` between languages with Amazon Translate.

    Args:
        text: Non-empty text to translate.
        source: Source language code, e.g. ``"es"``.
        target: Target language code, e.g. ``"en"``.

    Returns:
        The ``TranslatedText`` field of the service response.
    """
    response = translate_client.translate_text(
        Text=text,
        SourceLanguageCode=source,
        TargetLanguageCode=target,
        Settings={
            "Formality": "FORMAL",
        },
    )
    return response.get("TranslatedText")


# The translate API rejects empty text, so skip the whole round trip when
# the input box has been cleared instead of crashing the page.
if pregunta.strip():
    # The LLM works in English: translate the question, answer it, then
    # translate the answer back to Spanish for display.
    question = _translate(pregunta, "es", "en")
    answer = chain.run({"query": question})
    respuesta = _translate(answer, "en", "es")

    st.write(respuesta)
|
credentials.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"aws_access_key_id": "REDACTED-ROTATE-THIS-KEY",
|
3 |
+
"aws_secret_access_key": "REDACTED-ROTATE-THIS-KEY"
|
4 |
+
}
|
deployment/QA with OCR deployment.ipynb
ADDED
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "4e7adbd9",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# LLM Challenge: QA Over Documents with OCR Integration\n",
|
9 |
+
"\n",
|
10 |
+
"Para resolver el desafío usaré RAG (Retrieval Augmented Generation). La idea general de RAG es que el modelo recupera documentos contextuales de un conjunto de datos externo como parte de su ejecución, estos documentos contextuales se utilizan junto con la entrada original para producir la salida.\n",
|
11 |
+
"\n",
|
12 |
+
"![RAG](https://huggingface.co/blog/assets/12_ray_rag/rag_gif.gif \"RAG\")\n",
|
13 |
+
"\n",
|
14 |
+
"Para este caso los documentos serán particiones del contrato (aunque fácilmente podemos agregar más documentos) que nos ayudarán a agregar información contextual para que una LLM pueda contestar las preguntas que le hagamos."
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 2,
|
20 |
+
"id": "70669b3a",
|
21 |
+
"metadata": {},
|
22 |
+
"outputs": [],
|
23 |
+
"source": [
|
24 |
+
"! pip install sagemaker==2.204.0 amazon-textract-caller==0.2.1 amazon-textract-textractor==1.6.1 pypdf==4.0.0 --quiet\n",
|
25 |
+
"! pip install langchain==0.1.3 sentence-transformers==2.2.2 faiss-cpu==1.7.4 --quiet"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 3,
|
31 |
+
"id": "4b878325",
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [
|
34 |
+
{
|
35 |
+
"name": "stdout",
|
36 |
+
"output_type": "stream",
|
37 |
+
"text": [
|
38 |
+
"sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n",
|
39 |
+
"sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml\n"
|
40 |
+
]
|
41 |
+
}
|
42 |
+
],
|
43 |
+
"source": [
|
44 |
+
"import boto3\n",
|
45 |
+
"import json\n",
|
46 |
+
"import sagemaker\n",
|
47 |
+
"from typing import Dict\n",
|
48 |
+
"\n",
|
49 |
+
"from langchain import LLMChain\n",
|
50 |
+
"from langchain.docstore.document import Document\n",
|
51 |
+
"from langchain.prompts import PromptTemplate\n",
|
52 |
+
"from langchain.llms import SagemakerEndpoint\n",
|
53 |
+
"from langchain.llms.sagemaker_endpoint import LLMContentHandler\n",
|
54 |
+
"from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri\n",
|
55 |
+
"\n",
|
56 |
+
"from langchain.document_loaders import AmazonTextractPDFLoader\n",
|
57 |
+
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
|
58 |
+
"from langchain.embeddings import HuggingFaceEmbeddings\n",
|
59 |
+
"from langchain.vectorstores import FAISS\n",
|
60 |
+
"from langchain.chains import RetrievalQA"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "code",
|
65 |
+
"execution_count": 4,
|
66 |
+
"id": "674468f1",
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [],
|
69 |
+
"source": [
|
70 |
+
"with open('credentials.json') as file:\n",
|
71 |
+
" credentials = json.load(file)"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"cell_type": "markdown",
|
76 |
+
"id": "01002738",
|
77 |
+
"metadata": {},
|
78 |
+
"source": [
|
79 |
+
"## Despliegue de un LLM en SageMaker\n",
|
80 |
+
"\n",
|
81 |
+
"Hay una gran cantidad de [modelos](https://huggingface.co/models?pipeline_tag=text-generation) que podemos usar para la tarea. Inicialmente estaba usando algunos modelos multilingüísticos que incluían español, como [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1), pero los resultados que obtenía eran relativamente malos. Por lo tanto, decidí usar un modelo más robusto que estuviera entrenado únicamente en inglés y usar traducción automática para construir la base de datos vectorial necesaria y a la hora de hacer inferencia. Esta estrategia de traducir en general supera a los modelos multilingüísticos cuando se trata con [lenguajes con menos recursos](https://arxiv.org/abs/2311.09404).\n",
|
82 |
+
"\n",
|
83 |
+
"Particularmente elegí [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) para desplegar, el cual se desempeña bien en una multitud de tareas."
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": 5,
|
89 |
+
"id": "7fe712d7",
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [
|
92 |
+
{
|
93 |
+
"name": "stdout",
|
94 |
+
"output_type": "stream",
|
95 |
+
"text": [
|
96 |
+
"---------!"
|
97 |
+
]
|
98 |
+
}
|
99 |
+
],
|
100 |
+
"source": [
|
101 |
+
"# role = sagemaker.get_execution_role()\n",
|
102 |
+
"\n",
|
103 |
+
"iam = boto3.client(\n",
|
104 |
+
" service_name='iam',\n",
|
105 |
+
" region_name='us-east-1',\n",
|
106 |
+
" aws_access_key_id=credentials['aws_access_key_id'],\n",
|
107 |
+
" aws_secret_access_key=credentials['aws_secret_access_key']\n",
|
108 |
+
")\n",
|
109 |
+
"role = iam.get_role(RoleName='AmazonSageMakerServiceCatalogProductsUseRole')['Role']['Arn']\n",
|
110 |
+
"\n",
|
111 |
+
"\n",
|
112 |
+
"hub = {\n",
|
113 |
+
"\t'HF_MODEL_ID':'mistralai/Mistral-7B-Instruct-v0.1',\n",
|
114 |
+
"\t'SM_NUM_GPUS': '1'\n",
|
115 |
+
"}\n",
|
116 |
+
"\n",
|
117 |
+
"huggingface_model = HuggingFaceModel(\n",
|
118 |
+
"\timage_uri=get_huggingface_llm_image_uri(\"huggingface\",version=\"1.1.0\"),\n",
|
119 |
+
"\tenv=hub,\n",
|
120 |
+
"\trole=role \n",
|
121 |
+
")\n",
|
122 |
+
"\n",
|
123 |
+
"predictor = huggingface_model.deploy(\n",
|
124 |
+
" endpoint_name=\"mistral-langchain\",\n",
|
125 |
+
"\tinitial_instance_count=1,\n",
|
126 |
+
"\tinstance_type=\"ml.g5.2xlarge\",\n",
|
127 |
+
"\tcontainer_startup_health_check_timeout=300,\n",
|
128 |
+
")"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": 7,
|
134 |
+
"id": "f1ead13e",
|
135 |
+
"metadata": {},
|
136 |
+
"outputs": [
|
137 |
+
{
|
138 |
+
"data": {
|
139 |
+
"text/plain": [
|
140 |
+
"[{'generated_text': \"<s>[INST] What is your favourite food? [/INST] My favorite food is pizza. It's a versatile dish that can be customized to suit\"}]"
|
141 |
+
]
|
142 |
+
},
|
143 |
+
"execution_count": 7,
|
144 |
+
"metadata": {},
|
145 |
+
"output_type": "execute_result"
|
146 |
+
}
|
147 |
+
],
|
148 |
+
"source": [
|
149 |
+
"predictor.predict({\n",
|
150 |
+
"\t\"inputs\": \"<s>[INST] What is your favourite food? [/INST]\",\n",
|
151 |
+
"})"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": 7,
|
157 |
+
"id": "9736ae5b",
|
158 |
+
"metadata": {},
|
159 |
+
"outputs": [],
|
160 |
+
"source": [
|
161 |
+
"# help(huggingface_model.deploy)"
|
162 |
+
]
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"cell_type": "code",
|
166 |
+
"execution_count": 8,
|
167 |
+
"id": "42020dd9",
|
168 |
+
"metadata": {},
|
169 |
+
"outputs": [
|
170 |
+
{
|
171 |
+
"data": {
|
172 |
+
"text/plain": [
|
173 |
+
"'mistral-langchain'"
|
174 |
+
]
|
175 |
+
},
|
176 |
+
"execution_count": 8,
|
177 |
+
"metadata": {},
|
178 |
+
"output_type": "execute_result"
|
179 |
+
}
|
180 |
+
],
|
181 |
+
"source": [
|
182 |
+
"endpoint_name = predictor.endpoint_name\n",
|
183 |
+
"endpoint_name"
|
184 |
+
]
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "markdown",
|
188 |
+
"id": "7d1f55d2",
|
189 |
+
"metadata": {},
|
190 |
+
"source": [
|
191 |
+
"## OCR para extraer texto del pdf\n",
|
192 |
+
"\n",
|
193 |
+
"Debido a que el contrato es algo largo se debe particionar en pedazos más pequeños para poder vectorizar cada uno y poder usarlos posteriormente para RAG. Ya que hice el despliegue del LLM en AWS usaré la solución que ya tiene AWS para hacer OCR de documentos, el cual es [AWS Textract](https://aws.amazon.com/textract/). No me enfoqué tanto en la parte de OCR ya que me interesaba más completar el pipeline para el RAG. En esta parte es necesario traducir el contrato a inglés para posteriormente usar el LLM."
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": 10,
|
199 |
+
"id": "809ba42f",
|
200 |
+
"metadata": {},
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"uri = \"s3://sagemaker-us-east-1-607856618758/arkham-challenge/CONTRATO_AP000000718.pdf\"\n",
|
204 |
+
"\n",
|
205 |
+
"textract_client = boto3.client(\n",
|
206 |
+
" service_name='textract',\n",
|
207 |
+
" region_name='us-east-1',\n",
|
208 |
+
" aws_access_key_id=credentials['aws_access_key_id'],\n",
|
209 |
+
" aws_secret_access_key=credentials['aws_secret_access_key']\n",
|
210 |
+
")\n",
|
211 |
+
"translate_client = boto3.client(\n",
|
212 |
+
" service_name='translate',\n",
|
213 |
+
" region_name='us-east-1',\n",
|
214 |
+
" aws_access_key_id=credentials['aws_access_key_id'],\n",
|
215 |
+
" aws_secret_access_key=credentials['aws_secret_access_key']\n",
|
216 |
+
")"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": 12,
|
222 |
+
"id": "e8302f01",
|
223 |
+
"metadata": {},
|
224 |
+
"outputs": [
|
225 |
+
{
|
226 |
+
"name": "stdout",
|
227 |
+
"output_type": "stream",
|
228 |
+
"text": [
|
229 |
+
"Longitud del documento: 30 páginas, chunks: 453\n"
|
230 |
+
]
|
231 |
+
}
|
232 |
+
],
|
233 |
+
"source": [
|
234 |
+
"# OCR del contrato y división en pedazos más pequeños\n",
|
235 |
+
"loader = AmazonTextractPDFLoader(uri, client=textract_client)\n",
|
236 |
+
"document = loader.load()\n",
|
237 |
+
"\n",
|
238 |
+
"splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=64)\n",
|
239 |
+
"chunks = splitter.split_documents(document)\n",
|
240 |
+
" \n",
|
241 |
+
"for chunk in chunks:\n",
|
242 |
+
" chunk.page_content = translate_client.translate_text(\n",
|
243 |
+
" Text=chunk.page_content, \n",
|
244 |
+
" SourceLanguageCode=\"es\", \n",
|
245 |
+
" TargetLanguageCode='en',\n",
|
246 |
+
" Settings={\n",
|
247 |
+
" 'Formality': 'FORMAL',\n",
|
248 |
+
" }\n",
|
249 |
+
" ).get('TranslatedText')\n",
|
250 |
+
" \n",
|
251 |
+
"print(f\"Longitud del documento: {len(document)} páginas, chunks: {len(chunks)}\")"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"cell_type": "markdown",
|
256 |
+
"id": "4e2cc2f9",
|
257 |
+
"metadata": {},
|
258 |
+
"source": [
|
259 |
+
"## Embeddings de los documentos y base de datos vectorial\n",
|
260 |
+
"\n",
|
261 |
+
"Ahora necesitamos hacer una base de datos vectorial, hay una gran cantidad de [modelos](https://huggingface.co/models?pipeline_tag=feature-extraction) que podemos usar, no noté una gran diferencia en las respuestas finales usando distintos modelos, por lo que decidí usar un modelo relativamente pequeño [intfloat/multilingual-e5-small](https://huggingface.co/intfloat/multilingual-e5-small)."
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "code",
|
266 |
+
"execution_count": 13,
|
267 |
+
"id": "afc6629d",
|
268 |
+
"metadata": {},
|
269 |
+
"outputs": [],
|
270 |
+
"source": [
|
271 |
+
"# embedding_model_id = \"BAAI/bge-small-en-v1.5\"\n",
|
272 |
+
"# embedding_model_id = \"sentence-transformers/distiluse-base-multilingual-cased-v2\"\n",
|
273 |
+
"# embedding_model_id = \"intfloat/e5-mistral-7b-instruct\"\n",
|
274 |
+
"embedding_model_id = \"intfloat/multilingual-e5-small\"\n",
|
275 |
+
"\n",
|
276 |
+
"embeddings = HuggingFaceEmbeddings(\n",
|
277 |
+
" model_name=embedding_model_id,\n",
|
278 |
+
")"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 16,
|
284 |
+
"id": "9f66db6f",
|
285 |
+
"metadata": {},
|
286 |
+
"outputs": [],
|
287 |
+
"source": [
|
288 |
+
"embeddings_db = FAISS.from_documents(chunks, embeddings)\n",
|
289 |
+
"\n",
|
290 |
+
"# Base de datos vectorial que usaremos\n",
|
291 |
+
"embeddings_db.save_local(\"faiss_index\")"
|
292 |
+
]
|
293 |
+
},
|
294 |
+
{
|
295 |
+
"cell_type": "markdown",
|
296 |
+
"id": "9ad5efdf",
|
297 |
+
"metadata": {},
|
298 |
+
"source": [
|
299 |
+
"## RAG\n",
|
300 |
+
"\n",
|
301 |
+
"Ahora tenemos todo lo necesario para realizar RAG con el LLM que desplegamos y la base de datos vectorial con el contrato."
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"cell_type": "code",
|
306 |
+
"execution_count": 24,
|
307 |
+
"id": "828b09fa",
|
308 |
+
"metadata": {},
|
309 |
+
"outputs": [],
|
310 |
+
"source": [
|
311 |
+
"embeddings_db = FAISS.load_local(\"faiss_index\", embeddings)\n",
|
312 |
+
"retriever = embeddings_db.as_retriever(search_kwargs={\"k\": 2})"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"cell_type": "code",
|
317 |
+
"execution_count": 25,
|
318 |
+
"id": "ca0d9508",
|
319 |
+
"metadata": {},
|
320 |
+
"outputs": [],
|
321 |
+
"source": [
|
322 |
+
"model_kwargs = {\n",
|
323 |
+
" \"max_new_tokens\": 1024, \n",
|
324 |
+
" \"top_p\": 0.8, \n",
|
325 |
+
" \"temperature\": 0.1,\n",
|
326 |
+
" \"repetition_penalty\": 1.0\n",
|
327 |
+
"}\n",
|
328 |
+
"\n",
|
329 |
+
"class ContentHandler(LLMContentHandler):\n",
|
330 |
+
" content_type = \"application/json\"\n",
|
331 |
+
" accepts = \"application/json\"\n",
|
332 |
+
"\n",
|
333 |
+
" def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
|
334 |
+
" input_str = json.dumps(\n",
|
335 |
+
" # Template de prompt para Mistral, ver https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1\n",
|
336 |
+
" {\"inputs\": f\"<s>[INST] {prompt} [/INST]\", \"parameters\": {**model_kwargs}}\n",
|
337 |
+
" )\n",
|
338 |
+
" return input_str.encode(\"utf-8\")\n",
|
339 |
+
"\n",
|
340 |
+
" def transform_output(self, output: bytes) -> str:\n",
|
341 |
+
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
342 |
+
" splits = response_json[0][\"generated_text\"].split(\"[/INST] \")\n",
|
343 |
+
" return splits[1]\n",
|
344 |
+
"\n",
|
345 |
+
"content_handler = ContentHandler()\n",
|
346 |
+
"\n",
|
347 |
+
"sm_client = boto3.client(\n",
|
348 |
+
" service_name='sagemaker-runtime',\n",
|
349 |
+
" region_name='us-east-1',\n",
|
350 |
+
" aws_access_key_id=credentials['aws_access_key_id'],\n",
|
351 |
+
" aws_secret_access_key=credentials['aws_secret_access_key']\n",
|
352 |
+
")\n",
|
353 |
+
"\n",
|
354 |
+
"llm = SagemakerEndpoint(\n",
|
355 |
+
" endpoint_name=endpoint_name,\n",
|
356 |
+
" model_kwargs=model_kwargs,\n",
|
357 |
+
" content_handler=content_handler,\n",
|
358 |
+
" client=sm_client,\n",
|
359 |
+
")"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": 26,
|
365 |
+
"id": "c728567d",
|
366 |
+
"metadata": {},
|
367 |
+
"outputs": [],
|
368 |
+
"source": [
|
369 |
+
"prompt_template = \"\"\"\n",
|
370 |
+
"Please answer the question below, using only the context below. \n",
|
371 |
+
"Don't invent facts. If you can't provide a factual answer, say you don't know what the answer is.\n",
|
372 |
+
"\n",
|
373 |
+
"question: {question}\n",
|
374 |
+
"\n",
|
375 |
+
"context: {context}\n",
|
376 |
+
"\"\"\"\n",
|
377 |
+
"\n",
|
378 |
+
"prompt = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])"
|
379 |
+
]
|
380 |
+
},
|
381 |
+
{
|
382 |
+
"cell_type": "code",
|
383 |
+
"execution_count": 27,
|
384 |
+
"id": "15ad36c9",
|
385 |
+
"metadata": {},
|
386 |
+
"outputs": [],
|
387 |
+
"source": [
|
388 |
+
"chain = RetrievalQA.from_chain_type(\n",
|
389 |
+
" llm=llm, \n",
|
390 |
+
" chain_type=\"stuff\",\n",
|
391 |
+
" retriever=retriever, \n",
|
392 |
+
" chain_type_kwargs = {\"prompt\": prompt}\n",
|
393 |
+
")"
|
394 |
+
]
|
395 |
+
},
|
396 |
+
{
|
397 |
+
"cell_type": "markdown",
|
398 |
+
"id": "4063011b",
|
399 |
+
"metadata": {},
|
400 |
+
"source": [
|
401 |
+
"### QA con RAG"
|
402 |
+
]
|
403 |
+
},
|
404 |
+
{
|
405 |
+
"cell_type": "code",
|
406 |
+
"execution_count": 30,
|
407 |
+
"id": "7c0d8455",
|
408 |
+
"metadata": {},
|
409 |
+
"outputs": [
|
410 |
+
{
|
411 |
+
"name": "stdout",
|
412 |
+
"output_type": "stream",
|
413 |
+
"text": [
|
414 |
+
"Answer: Based on the provided context, the depositary is Oscar Alberto Islas Mendoza. \n",
|
415 |
+
"\n",
|
416 |
+
"Respuesta: Según el contexto proporcionado, el depositario es Oscar Alberto Islas Mendoza.\n"
|
417 |
+
]
|
418 |
+
}
|
419 |
+
],
|
420 |
+
"source": [
|
421 |
+
"pregunta = \"¿Quién es el depositario?\"\n",
|
422 |
+
"question = translate_client.translate_text(\n",
|
423 |
+
" Text=pregunta, \n",
|
424 |
+
" SourceLanguageCode=\"es\", \n",
|
425 |
+
" TargetLanguageCode='en',\n",
|
426 |
+
" Settings={\n",
|
427 |
+
" 'Formality': 'FORMAL',\n",
|
428 |
+
" }\n",
|
429 |
+
").get('TranslatedText')\n",
|
430 |
+
"\n",
|
431 |
+
"answer = chain.run({\"query\": question})\n",
|
432 |
+
"respuesta = translate_client.translate_text(\n",
|
433 |
+
" Text=answer, \n",
|
434 |
+
" SourceLanguageCode=\"en\", \n",
|
435 |
+
" TargetLanguageCode='es',\n",
|
436 |
+
" Settings={\n",
|
437 |
+
" 'Formality': 'FORMAL',\n",
|
438 |
+
" }\n",
|
439 |
+
").get('TranslatedText')\n",
|
440 |
+
"\n",
|
441 |
+
"\n",
|
442 |
+
"print(f\"Answer: {answer} \\n\")\n",
|
443 |
+
"print(f\"Respuesta: {respuesta}\")"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"cell_type": "code",
|
448 |
+
"execution_count": 35,
|
449 |
+
"id": "357232b7",
|
450 |
+
"metadata": {},
|
451 |
+
"outputs": [
|
452 |
+
{
|
453 |
+
"name": "stdout",
|
454 |
+
"output_type": "stream",
|
455 |
+
"text": [
|
456 |
+
"Answer: Replacement value refers to the commercial value of equipment on a given date, plus the current Value Added Tax (\"VAT\"). It is used to determine the amount of advance payment that is required for the equipment. The difference between the replacement value and the periodic payments, if any, is calculated to determine the amount of any damages that may be owed. \n",
|
457 |
+
"\n",
|
458 |
+
"Respuesta: El valor de reposición se refiere al valor comercial del equipo en una fecha determinada, más el impuesto sobre el valor añadido («IVA») actual. Se utiliza para determinar el importe de anticipo que se requiere para el equipo. La diferencia entre el valor de reposición y los pagos periódicos, si los hubiera, se calcula para determinar el importe de cualquier daño que pueda adeudarse.\n"
|
459 |
+
]
|
460 |
+
}
|
461 |
+
],
|
462 |
+
"source": [
|
463 |
+
"pregunta = \"¿Qué es valor de reposición?\"\n",
|
464 |
+
"question = translate_client.translate_text(\n",
|
465 |
+
" Text=pregunta, \n",
|
466 |
+
" SourceLanguageCode=\"es\", \n",
|
467 |
+
" TargetLanguageCode='en',\n",
|
468 |
+
" Settings={\n",
|
469 |
+
" 'Formality': 'FORMAL',\n",
|
470 |
+
" }\n",
|
471 |
+
").get('TranslatedText')\n",
|
472 |
+
"\n",
|
473 |
+
"answer = chain.run({\"query\": question})\n",
|
474 |
+
"respuesta = translate_client.translate_text(\n",
|
475 |
+
" Text=answer, \n",
|
476 |
+
" SourceLanguageCode=\"en\", \n",
|
477 |
+
" TargetLanguageCode='es',\n",
|
478 |
+
" Settings={\n",
|
479 |
+
" 'Formality': 'FORMAL',\n",
|
480 |
+
" }\n",
|
481 |
+
").get('TranslatedText')\n",
|
482 |
+
"\n",
|
483 |
+
"\n",
|
484 |
+
"print(f\"Answer: {answer} \\n\")\n",
|
485 |
+
"print(f\"Respuesta: {respuesta}\")"
|
486 |
+
]
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"cell_type": "code",
|
490 |
+
"execution_count": 28,
|
491 |
+
"id": "3dc37600",
|
492 |
+
"metadata": {},
|
493 |
+
"outputs": [],
|
494 |
+
"source": [
|
495 |
+
"# predictor.delete_model()\n",
|
496 |
+
"# predictor.delete_endpoint()"
|
497 |
+
]
|
498 |
+
}
|
499 |
+
],
|
500 |
+
"metadata": {
|
501 |
+
"kernelspec": {
|
502 |
+
"display_name": "conda_pytorch_p310",
|
503 |
+
"language": "python",
|
504 |
+
"name": "conda_pytorch_p310"
|
505 |
+
},
|
506 |
+
"language_info": {
|
507 |
+
"codemirror_mode": {
|
508 |
+
"name": "ipython",
|
509 |
+
"version": 3
|
510 |
+
},
|
511 |
+
"file_extension": ".py",
|
512 |
+
"mimetype": "text/x-python",
|
513 |
+
"name": "python",
|
514 |
+
"nbconvert_exporter": "python",
|
515 |
+
"pygments_lexer": "ipython3",
|
516 |
+
"version": "3.10.13"
|
517 |
+
}
|
518 |
+
},
|
519 |
+
"nbformat": 4,
|
520 |
+
"nbformat_minor": 5
|
521 |
+
}
|
deployment/credentials.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"aws_access_key_id": "REDACTED-ROTATE-THIS-KEY",
|
3 |
+
"aws_secret_access_key": "REDACTED-ROTATE-THIS-KEY"
|
4 |
+
}
|
faiss_index/index.faiss
ADDED
Binary file (696 kB). View file
|
|
faiss_index/index.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dd4ce77807d1b646571ce588db4200b1cd71297a114efa37e468a99ad9fae365
|
3 |
+
size 141016
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
boto3==1.34.19
|
2 |
+
faiss-cpu==1.7.4
|
3 |
+
langchain==0.1.3
|
4 |
+
sagemaker==2.204.0
|
5 |
+
streamlit==1.29.0
|
6 |
+
sentence-transformers==2.2.2
|
7 |
+
transformers==4.37.0
|