acumplid committed
Commit • a8bf50c
1 Parent(s): a6487f4

base app
Files changed:
- .gitignore +5 -0
- README.md +6 -5
- app.py +266 -0
- handler.py +14 -0
- input_reader.py +22 -0
- rag.py +165 -0
- rag_image.jpg +0 -0
- requirements.txt +14 -0
- utils.py +33 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+/venv
+/venv/*
+.env
+__pycache__
+__pycache__/*
README.md
CHANGED
@@ -1,12 +1,13 @@
 ---
-title: MLhouse
-emoji:
-colorFrom:
-colorTo:
+title: MLhouse-RAG
+emoji: 💻
+colorFrom: indigo
+colorTo: yellow
 sdk: gradio
-sdk_version:
+sdk_version: 4.24.0
 app_file: app.py
 pinned: false
+license: apache-2.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,266 @@
+import os
+import gradio as gr
+from gradio.components import Textbox, Button, Slider, Checkbox
+from AinaTheme import theme
+from urllib.error import HTTPError
+
+from rag import RAG
+from utils import setup
+
+MAX_NEW_TOKENS = 700
+SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
+
+setup()
+
+
+rag = RAG(
+    hf_token=os.getenv("HF_TOKEN"),
+    embeddings_model=os.getenv("EMBEDDINGS"),
+    repo_name=os.getenv("REPO_NAME"),
+)
+
+# model_name=os.getenv("MODEL"),
+# rerank_model=os.getenv("RERANK_MODEL"),
+# rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS"))
+
+def generate(prompt, model_parameters):
+    try:
+        output, context, source = rag.get_response(prompt, model_parameters)
+        return output, context, source
+    except HTTPError as err:
+        if err.code == 400:
+            gr.Warning(
+                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
+            )
+    except:
+        gr.Warning(
+            "Inference endpoint is not available right now. Please try again later."
+        )
+    return None, None, None
+
+
+def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
+    if input_.strip() == "":
+        gr.Warning("Not possible to inference an empty input")
+        return None
+
+
+    model_parameters = {
+        "NUM_CHUNKS": num_chunks,
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+        "top_k": top_k,
+        "top_p": top_p,
+        "do_sample": do_sample,
+        "temperature": temperature
+    }
+
+    output, context, source = generate(input_, model_parameters)
+    sources_markup = ""
+
+    for url in source:
+        sources_markup += f'<a href="{url}" target="_blank">{url}</a><br>'
+
+    return output, sources_markup, context
+    # return output.strip(), sources_markup, context
+
+
+def change_interactive(text):
+    if len(text) == 0:
+        return gr.update(interactive=True), gr.update(interactive=False)
+    return gr.update(interactive=True), gr.update(interactive=True)
+
+
+def clear():
+    return (
+        None,
+        None,
+        None,
+        None,
+        gr.Slider(value=2.0),
+        gr.Slider(value=MAX_NEW_TOKENS),
+        gr.Slider(value=1.0),
+        gr.Slider(value=50),
+        gr.Slider(value=0.99),
+        gr.Checkbox(value=False),
+        gr.Slider(value=0.35),
+    )
+
+
+def gradio_app():
+    with gr.Blocks(theme=theme) as demo:
+        with gr.Row():
+            with gr.Column(scale=0.1):
+                gr.Image("rag_image.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button = False, show_share_button = False)
+            with gr.Column():
+                gr.Markdown(
+                    """# Demo de Retrieval-Augmented Generation per documents legals
+                    🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
+                    en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
+                    fent servir només la informació existent en els documents del repositori.
+
+                    🎯 **Objectiu:** Aquest és un demostrador amb la normativa vigent publicada al Diari Oficial de la Generalitat de Catalunya, en el
+                    repositori del EADOP (Entitat Autònoma del Diari Oficial i de Publicacions). Aquesta versió explora prop de 2000 documents en català,
+                    i genera la resposta fent servir el model Salamandra-7b-aligned-EADOP, el model BSC-LT/salamandra-7b-instruct alineat amb el dataset de alinia/EADOP-RAG-out-of-domain.
+
+                    ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
+                    Si us plau, tingueu-ho en compte quan exploreu aquest recurs.
+                    """
+                )
+        with gr.Row(equal_height=True):
+            with gr.Column(variant="panel"):
+                input_ = Textbox(
+                    lines=11,
+                    label="Input",
+                    placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
+                    # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
+                )
+                with gr.Row(variant="panel"):
+                    clear_btn = Button(
+                        "Clear",
+                    )
+                    submit_btn = Button("Submit", variant="primary", interactive=False)
+
+                with gr.Row(variant="panel"):
+                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+                        num_chunks = Slider(
+                            minimum=1,
+                            maximum=6,
+                            step=1,
+                            value=2,
+                            label="Number of chunks"
+                        )
+                        max_new_tokens = Slider(
+                            minimum=50,
+                            maximum=2000,
+                            step=1,
+                            value=MAX_NEW_TOKENS,
+                            label="Max tokens"
+                        )
+                        repetition_penalty = Slider(
+                            minimum=0.1,
+                            maximum=2.0,
+                            step=0.1,
+                            value=1.0,
+                            label="Repetition penalty"
+                        )
+                        top_k = Slider(
+                            minimum=1,
+                            maximum=100,
+                            step=1,
+                            value=50,
+                            label="Top k"
+                        )
+                        top_p = Slider(
+                            minimum=0.01,
+                            maximum=0.99,
+                            value=0.99,
+                            label="Top p"
+                        )
+                        do_sample = Checkbox(
+                            value=False,
+                            label="Do sample"
+                        )
+                        temperature = Slider(
+                            minimum=0.1,
+                            maximum=1,
+                            value=0.35,
+                            label="Temperature"
+                        )
+
+                parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
+
+            with gr.Column(variant="panel"):
+                output = Textbox(
+                    lines=10,
+                    label="Output",
+                    interactive=False,
+                    show_copy_button=True
+                )
+                with gr.Accordion("Sources and context:", open=False):
+                    source_context = gr.Markdown(
+                        label="Sources",
+                        show_label=False,
+                    )
+                with gr.Accordion("See full context evaluation:", open=False):
+                    context_evaluation = gr.Markdown(
+                        label="Full context",
+                        show_label=False,
+                        # interactive=False,
+                        # autoscroll=False,
+                        # show_copy_button=True
+                    )
+
+
+        input_.change(
+            fn=change_interactive,
+            inputs=[input_],
+            outputs=[clear_btn, submit_btn],
+            api_name=False,
+        )
+
+        input_.change(
+            fn=None,
+            inputs=[input_],
+            api_name=False,
+            js="""(i, m) => {
+                document.getElementById('inputlenght').textContent = i.length + ' '
+                document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
+            }""",
+        )
+
+        clear_btn.click(
+            fn=clear,
+            inputs=[],
+            outputs=[input_, output, source_context, context_evaluation] + parameters_compontents,
+            queue=False,
+            api_name=False
+        )
+
+        submit_btn.click(
+            fn=submit_input,
+            inputs=[input_] + parameters_compontents,
+            outputs=[output, source_context, context_evaluation],
+            api_name="get-results"
+        )
+
+        with gr.Row():
+            with gr.Column(scale=0.5):
+                gr.Examples(
+                    examples=[
+                        ["""Què és l'EADOP (Entitat Autònoma del Diari Oficial i de Publicacions)?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Què diu el decret sobre la senyalització de les begudes alcohòliques i el tabac a Catalunya?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Com es pot inscriure una persona al Registre de catalans i catalanes residents a l'exterior?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Quina és la finalitat del Servei Meterològic de Catalunya ?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+
+    demo.launch(show_api=True)
+
+
+if __name__ == "__main__":
+    gradio_app()
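Note: because submit_btn.click is registered with api_name="get-results" and the demo is launched with show_api=True, the Space can also be queried without the UI. A minimal sketch using gradio_client; the Space id below is a hypothetical placeholder and the argument values mirror the UI defaults.

# Minimal sketch: call the "get-results" endpoint of the deployed Space.
# "your-username/MLhouse-RAG" is a hypothetical Space id; replace it with the real one.
# Requires `pip install gradio_client`.
from gradio_client import Client

client = Client("your-username/MLhouse-RAG")  # hypothetical Space id
output, sources_markup, context = client.predict(
    "Quina és la finalitat del Servei Meteorològic de Catalunya?",  # input_
    2,      # num_chunks
    700,    # max_new_tokens
    1.0,    # repetition_penalty
    50,     # top_k
    0.99,   # top_p
    False,  # do_sample
    0.35,   # temperature
    api_name="/get-results",
)
print(output)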
handler.py
ADDED
@@ -0,0 +1,14 @@
+import json
+
+class ContentHandler():
+    content_type = "application/json"
+    accepts = "application/json"
+
+    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
+        input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
+        return input_str.encode('utf-8')
+
+    def transform_output(self, output: bytes) -> str:
+        response_json = json.loads(output.read().decode("utf-8"))
+        return response_json[0]["generated_text"]
+
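Note: transform_output reads from a file-like response body (it calls output.read()), so the round trip can be sanity-checked with an io.BytesIO stand-in. A minimal sketch with a made-up payload:

# Minimal round-trip sketch for ContentHandler; the JSON below is a made-up
# stand-in for what an inference endpoint would return.
import io
from handler import ContentHandler

handler = ContentHandler()
body = handler.transform_input("Hola", {"max_new_tokens": 200})  # bytes ready to POST
fake_response = io.BytesIO(b'[{"generated_text": "Hola, com puc ajudar-te?"}]')
print(handler.transform_output(fake_response))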
input_reader.py
ADDED
@@ -0,0 +1,22 @@
+from typing import List
+
+from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
+from llama_index.core.readers import SimpleDirectoryReader
+from llama_index.core.schema import Document
+from llama_index.core import Settings
+
+
+class InputReader:
+    def __init__(self, input_dir: str) -> None:
+        self.reader = SimpleDirectoryReader(input_dir=input_dir)
+
+    def parse_documents(
+        self,
+        show_progress: bool = True,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    ) -> List[Document]:
+        Settings.chunk_size = chunk_size
+        Settings.chunk_overlap = chunk_overlap
+        documents = self.reader.load_data(show_progress=show_progress)
+        return documents
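Note: InputReader is not wired into app.py in this commit; it would be used offline to chunk the source documents before a vector store is built. A minimal sketch, assuming a hypothetical data/ directory of documents:

# Minimal sketch: chunk a local folder of documents with overridden llama-index settings.
# "data/" is a hypothetical path, not part of this commit.
from input_reader import InputReader

reader = InputReader(input_dir="data/")
documents = reader.parse_documents(chunk_size=512, chunk_overlap=64)
print(f"Loaded {len(documents)} document chunks")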
rag.py
ADDED
@@ -0,0 +1,165 @@
+import logging
+import os
+import requests
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+from openai import OpenAI
+from huggingface_hub import snapshot_download
+
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+
+class RAG:
+    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."
+
+
+
+    # Download the vectorstore from Hugging Face Hub
+
+    def __init__(self, hf_token, embeddings_model, repo_name):
+
+        vectorstore = snapshot_download(repo_name)
+
+
+        # self.model_name = model_name
+        self.hf_token = hf_token
+        # self.rerank_model = rerank_model
+        # self.rerank_number_contexts = rerank_number_contexts
+
+        # load vectore store
+        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
+        self.vectore_store = FAISS.load_local(vectorstore, embeddings, allow_dangerous_deserialization=True)#, allow_dangerous_deserialization=True)
+
+        logging.info("RAG loaded!")
+
+    # def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
+    #     """
+    #     Rerank the contexts based on their relevance to the given instruction.
+    #     """
+
+    #     rerank_model = self.rerank_model
+
+
+    #     tokenizer = AutoTokenizer.from_pretrained(rerank_model)
+    #     model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
+
+    #     def get_score(query, passage):
+    #         """Calculate the relevance score of a passage with respect to a query."""
+
+
+    #         inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
+
+
+    #         with torch.no_grad():
+    #             outputs = model(**inputs)
+
+
+    #         logits = outputs.logits
+
+
+    #         score = logits.view(-1, ).float()
+
+
+    #         return score
+
+    #     scores = [get_score(instruction, c[0].page_content) for c in contexts]
+    #     combined = list(zip(contexts, scores))
+    #     sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
+    #     sorted_texts, _ = zip(*sorted_combined)
+
+    #     return sorted_texts[:number_of_contexts]
+
+    def get_context(self, instruction, number_of_contexts=2):
+        """Retrieve the most relevant contexts for a given instruction."""
+        documentos = self.vectore_store.similarity_search_with_score(instruction, k=4)
+
+        # documentos = self.rerank_contexts(instruction, documentos, number_of_contexts=number_of_contexts)
+
+        print("Reranked documents")
+        return documentos
+
+    def predict_dolly(self, instruction, context, model_parameters):
+
+        api_key = os.getenv("HF_TOKEN")
+
+
+        headers = {
+            "Accept" : "application/json",
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+
+        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "
+        #prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
+
+
+        payload = {
+            "inputs": query,
+            "parameters": model_parameters
+        }
+
+        response = requests.post(self.model_name, headers=headers, json=payload)
+
+        return response.json()[0]["generated_text"].split("###")[-1][8:]
+
+    def predict_completion(self, instruction, context, model_parameters):
+
+        client = OpenAI(
+            base_url=os.getenv("MODEL"),
+            api_key=os.getenv("HF_TOKEN")
+        )
+
+        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+
+        chat_completion = client.chat.completions.create(
+            model="tgi",
+            messages=[
+                {"role": "user", "content": instruction}
+            ],
+            temperature=model_parameters["temperature"],
+            max_tokens=model_parameters["max_new_tokens"],
+            stream=False,
+            stop=["<|im_end|>"],
+            extra_body = {
+                "presence_penalty": model_parameters["repetition_penalty"] - 2,
+                "do_sample": False
+            }
+        )
+
+        response = chat_completion.choices[0].message.content
+
+        return response
+
+
+    def beautiful_context(self, docs):
+
+        text_context = ""
+
+        full_context = ""
+        source_context = []
+        for doc in docs:
+            text_context += doc[0].page_content
+            full_context += doc[0].page_content + "\n"
+            full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
+            full_context += doc[0].metadata["url"] + "\n\n"
+            source_context.append(doc[0].metadata["url"])
+
+        return text_context, full_context, source_context
+
+    def get_response(self, prompt: str, model_parameters: dict) -> str:
+        try:
+            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
+            text_context, full_context, source = self.beautiful_context(docs)
+
+            del model_parameters["NUM_CHUNKS"]
+
+            # response = self.predict_completion(prompt, text_context, model_parameters)
+            response = "Output"
+
+            if not response:
+                return self.NO_ANSWER_MESSAGE
+
+            return response, full_context, source
+        except Exception as err:
+            print(err)
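Note: the class can be exercised outside Gradio the same way app.py constructs it, assuming HF_TOKEN, EMBEDDINGS and REPO_NAME are set (for example through the .env file that utils.setup() loads). As committed, get_response() returns the fixed string "Output" because the predict_completion() call is commented out. A minimal sketch:

# Minimal sketch: build RAG and fetch contexts/sources directly, mirroring app.py.
# Environment variables are assumed to be provided via .env.
import os
from dotenv import load_dotenv
from rag import RAG

load_dotenv()
rag = RAG(
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"),
    repo_name=os.getenv("REPO_NAME"),
)
answer, full_context, sources = rag.get_response(
    "Què és l'EADOP?",
    {"NUM_CHUNKS": 2, "max_new_tokens": 700, "repetition_penalty": 1.0, "temperature": 0.35},
)
print(answer)   # "Output" until predict_completion() is re-enabled
print(sources)  # list of DOGC URLs for the retrieved chunks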
rag_image.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,14 @@
+gradio==4.29.0
+huggingface-hub==0.23.4
+python-dotenv==1.0.0
+llama-index==0.10.14
+llama-index-embeddings-huggingface==0.2.2
+llama-index-llms-huggingface==0.2.4
+sentence-transformers==2.7.0
+langchain
+faiss-cpu
+aina-gradio-theme==2.3
+
+langchain-community==0.2.1
+langchain-core==0.2.1
+openai==1.35.12
utils.py
ADDED
@@ -0,0 +1,33 @@
+import logging
+import warnings
+
+from dotenv import load_dotenv
+
+
+from rag import RAG
+
+USER_INPUT = 100
+
+
+def setup():
+    load_dotenv()
+    warnings.filterwarnings("ignore")
+
+    logging.addLevelName(USER_INPUT, "USER_INPUT")
+    logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO)
+
+
+def interactive(model: RAG):
+    logging.info("Write `exit` when you want to stop the model.")
+    print()
+
+    query = ""
+    while query.lower() != "exit":
+        logging.log(USER_INPUT, "Write the query or `exit`:")
+        query = input()
+
+        if query.lower() == "exit":
+            break
+
+        response = model.get_response(query)
+        print(response, end="\n\n")
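Note: setup() registers a custom USER_INPUT logging level (100) so prompts are always printed regardless of the INFO threshold; a minimal sketch of using it directly is below. As committed, interactive() calls model.get_response(query) with a single argument, while rag.RAG.get_response also expects a model_parameters dict, so the CLI loop would need that second argument before it could run.

# Minimal sketch of the custom logging level registered by setup().
import logging
from utils import setup, USER_INPUT

setup()
logging.log(USER_INPUT, "Write the query or `exit`:")  # always shown at level 100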