acumplid committed
Commit a8bf50c
1 parent: a6487f4
Files changed (9)
  1. .gitignore +5 -0
  2. README.md +6 -5
  3. app.py +266 -0
  4. handler.py +14 -0
  5. input_reader.py +22 -0
  6. rag.py +165 -0
  7. rag_image.jpg +0 -0
  8. requirements.txt +14 -0
  9. utils.py +33 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+/venv
+/venv/*
+.env
+__pycache__
+__pycache__/*
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
-title: MLhouse RAG
-emoji: 😻
-colorFrom: blue
-colorTo: pink
+title: MLhouse-RAG
+emoji: 💻
+colorFrom: indigo
+colorTo: yellow
 sdk: gradio
-sdk_version: 5.8.0
+sdk_version: 4.24.0
 app_file: app.py
 pinned: false
+license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,266 @@
+import os
+import gradio as gr
+from gradio.components import Textbox, Button, Slider, Checkbox
+from AinaTheme import theme
+from urllib.error import HTTPError
+
+from rag import RAG
+from utils import setup
+
+MAX_NEW_TOKENS = 700
+SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
+
+setup()
+
+
+rag = RAG(
+    hf_token=os.getenv("HF_TOKEN"),
+    embeddings_model=os.getenv("EMBEDDINGS"),
+    repo_name=os.getenv("REPO_NAME"),
+)
+
+# model_name=os.getenv("MODEL"),
+# rerank_model=os.getenv("RERANK_MODEL"),
+# rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS"))
+
+def generate(prompt, model_parameters):
+    try:
+        output, context, source = rag.get_response(prompt, model_parameters)
+        return output, context, source
+    except HTTPError as err:
+        if err.code == 400:
+            gr.Warning(
+                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
+            )
+    except Exception:
+        gr.Warning(
+            "Inference endpoint is not available right now. Please try again later."
+        )
+    return None, None, None
+
+
+def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
+    if input_.strip() == "":
+        gr.Warning("Cannot run inference on an empty input")
+        return None, None, None
+
+
+    model_parameters = {
+        "NUM_CHUNKS": num_chunks,
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+        "top_k": top_k,
+        "top_p": top_p,
+        "do_sample": do_sample,
+        "temperature": temperature
+    }
+
+    output, context, source = generate(input_, model_parameters)
+    sources_markup = ""
+
+    for url in source or []:
+        sources_markup += f'<a href="{url}" target="_blank">{url}</a><br>'
+
+    return output, sources_markup, context
+    # return output.strip(), sources_markup, context
+
+
+def change_interactive(text):
+    if len(text) == 0:
+        return gr.update(interactive=True), gr.update(interactive=False)
+    return gr.update(interactive=True), gr.update(interactive=True)
+
+
+def clear():
+    return (
+        None,
+        None,
+        None,
+        None,
+        gr.Slider(value=2.0),
+        gr.Slider(value=MAX_NEW_TOKENS),
+        gr.Slider(value=1.0),
+        gr.Slider(value=50),
+        gr.Slider(value=0.99),
+        gr.Checkbox(value=False),
+        gr.Slider(value=0.35),
+    )
+
+
+def gradio_app():
+    with gr.Blocks(theme=theme) as demo:
+        with gr.Row():
+            with gr.Column(scale=0.1):
+                gr.Image("rag_image.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button=False, show_share_button=False)
+            with gr.Column():
+                gr.Markdown(
+                    """# Demo de Retrieval-Augmented Generation per a documents legals
+🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes
+en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta
+fent servir només la informació existent en els documents del repositori.
+
+🎯 **Objectiu:** Aquest és un demostrador amb la normativa vigent publicada al Diari Oficial de la Generalitat de Catalunya, en el
+repositori de l'EADOP (Entitat Autònoma del Diari Oficial i de Publicacions). Aquesta versió explora prop de 2000 documents en català,
+i genera la resposta fent servir el model Salamandra-7b-aligned-EADOP, el model BSC-LT/salamandra-7b-instruct alineat amb el dataset d'alinia/EADOP-RAG-out-of-domain.
+
+⚠️ **Advertències**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte.
+Si us plau, tingueu-ho en compte quan exploreu aquest recurs.
+"""
+                )
+        with gr.Row(equal_height=True):
+            with gr.Column(variant="panel"):
+                input_ = Textbox(
+                    lines=11,
+                    label="Input",
+                    placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
+                    # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
+                )
+                with gr.Row(variant="panel"):
+                    clear_btn = Button(
+                        "Clear",
+                    )
+                    submit_btn = Button("Submit", variant="primary", interactive=False)
+
+                with gr.Row(variant="panel"):
+                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+                        num_chunks = Slider(
+                            minimum=1,
+                            maximum=6,
+                            step=1,
+                            value=2,
+                            label="Number of chunks"
+                        )
+                        max_new_tokens = Slider(
+                            minimum=50,
+                            maximum=2000,
+                            step=1,
+                            value=MAX_NEW_TOKENS,
+                            label="Max tokens"
+                        )
+                        repetition_penalty = Slider(
+                            minimum=0.1,
+                            maximum=2.0,
+                            step=0.1,
+                            value=1.0,
+                            label="Repetition penalty"
+                        )
+                        top_k = Slider(
+                            minimum=1,
+                            maximum=100,
+                            step=1,
+                            value=50,
+                            label="Top k"
+                        )
+                        top_p = Slider(
+                            minimum=0.01,
+                            maximum=0.99,
+                            value=0.99,
+                            label="Top p"
+                        )
+                        do_sample = Checkbox(
+                            value=False,
+                            label="Do sample"
+                        )
+                        temperature = Slider(
+                            minimum=0.1,
+                            maximum=1,
+                            value=0.35,
+                            label="Temperature"
+                        )
+
+                parameters_components = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
+
+            with gr.Column(variant="panel"):
+                output = Textbox(
+                    lines=10,
+                    label="Output",
+                    interactive=False,
+                    show_copy_button=True
+                )
+                with gr.Accordion("Sources and context:", open=False):
+                    source_context = gr.Markdown(
+                        label="Sources",
+                        show_label=False,
+                    )
+                with gr.Accordion("See full context evaluation:", open=False):
+                    context_evaluation = gr.Markdown(
+                        label="Full context",
+                        show_label=False,
+                        # interactive=False,
+                        # autoscroll=False,
+                        # show_copy_button=True
+                    )
+
+
+        input_.change(
+            fn=change_interactive,
+            inputs=[input_],
+            outputs=[clear_btn, submit_btn],
+            api_name=False,
+        )
+
+        input_.change(
+            fn=None,
+            inputs=[input_],
+            api_name=False,
+            js="""(i, m) => {
+                document.getElementById('inputlenght').textContent = i.length + ' '
+                document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
+            }""",
+        )
+
+        clear_btn.click(
+            fn=clear,
+            inputs=[],
+            outputs=[input_, output, source_context, context_evaluation] + parameters_components,
+            queue=False,
+            api_name=False
+        )
+
+        submit_btn.click(
+            fn=submit_input,
+            inputs=[input_] + parameters_components,
+            outputs=[output, source_context, context_evaluation],
+            api_name="get-results"
+        )
+
+        with gr.Row():
+            with gr.Column(scale=0.5):
+                gr.Examples(
+                    examples=[
+                        ["""Què és l'EADOP (Entitat Autònoma del Diari Oficial i de Publicacions)?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Què diu el decret sobre la senyalització de les begudes alcohòliques i el tabac a Catalunya?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Com es pot inscriure una persona al Registre de catalans i catalanes residents a l'exterior?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+                gr.Examples(
+                    examples=[
+                        ["""Quina és la finalitat del Servei Meteorològic de Catalunya?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+
+    demo.launch(show_api=True)
+
+
+if __name__ == "__main__":
+    gradio_app()
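
Because submit_btn.click is registered with api_name="get-results" and the app launches with show_api=True, the Space can also be driven programmatically. A minimal client-side sketch using gradio_client; the Space id "user/mlhouse-rag" is a placeholder, and the positional arguments mirror [input_] + parameters_components:

# Hypothetical caller for the "get-results" endpoint; the Space id is a placeholder.
from gradio_client import Client

client = Client("user/mlhouse-rag")  # replace with the real Space id
output, sources, context = client.predict(
    "Quina és la finalitat del Servei Meteorològic de Catalunya?",  # input_
    2,       # num_chunks
    700,     # max_new_tokens
    1.0,     # repetition_penalty
    50,      # top_k
    0.99,    # top_p
    False,   # do_sample
    0.35,    # temperature
    api_name="/get-results",
)
print(output)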
handler.py ADDED
@@ -0,0 +1,14 @@
+import json
+
+class ContentHandler:
+    content_type = "application/json"
+    accepts = "application/json"
+
+    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
+        input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
+        return input_str.encode('utf-8')
+
+    def transform_output(self, output) -> str:  # `output` is a file-like response body, not raw bytes
+        response_json = json.loads(output.read().decode("utf-8"))
+        return response_json[0]["generated_text"]
+
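
ContentHandler mirrors the serialization contract of a text-generation endpoint: transform_input packs the prompt and parameters into JSON bytes, and transform_output reads a file-like response body and extracts generated_text. A quick round-trip sketch, using io.BytesIO to stand in for the streaming body:

# Round-trip sketch; the response payload below is fabricated for illustration.
import io
from handler import ContentHandler

handler = ContentHandler()

body = handler.transform_input("Què és l'EADOP?", {"max_new_tokens": 700})
print(body)  # JSON bytes: {"inputs": ..., "parameters": {"max_new_tokens": 700}}

fake_body = io.BytesIO(b'[{"generated_text": "Resposta de prova"}]')
print(handler.transform_output(fake_body))  # "Resposta de prova"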
input_reader.py ADDED
@@ -0,0 +1,22 @@
+from typing import List
+
+from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
+from llama_index.core.readers import SimpleDirectoryReader
+from llama_index.core.schema import Document
+from llama_index.core import Settings
+
+
+class InputReader:
+    def __init__(self, input_dir: str) -> None:
+        self.reader = SimpleDirectoryReader(input_dir=input_dir)
+
+    def parse_documents(
+        self,
+        show_progress: bool = True,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    ) -> List[Document]:
+        Settings.chunk_size = chunk_size
+        Settings.chunk_overlap = chunk_overlap
+        documents = self.reader.load_data(show_progress=show_progress)
+        return documents
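
InputReader wraps llama_index's SimpleDirectoryReader and stores the chunking configuration on the global Settings object, so the chunk size and overlap take effect later, when the loaded documents are split and indexed. A hypothetical usage sketch; "documents/" is a placeholder directory:

from input_reader import InputReader

reader = InputReader(input_dir="documents/")  # placeholder path with source files
docs = reader.parse_documents(show_progress=True, chunk_size=512, chunk_overlap=64)
print(f"Loaded {len(docs)} documents")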
rag.py ADDED
@@ -0,0 +1,165 @@
+import logging
+import os
+import requests
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+from openai import OpenAI
+from huggingface_hub import snapshot_download
+
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+
+class RAG:
+    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."
+
+
+
+    # Download the vectorstore from the Hugging Face Hub
+
+    def __init__(self, hf_token, embeddings_model, repo_name):
+
+        vectorstore = snapshot_download(repo_name)
+
+        # self.model_name = model_name
+        self.hf_token = hf_token
+        # self.rerank_model = rerank_model
+        # self.rerank_number_contexts = rerank_number_contexts
+
+        # load vector store
+        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
+        self.vector_store = FAISS.load_local(vectorstore, embeddings, allow_dangerous_deserialization=True)
+
+        logging.info("RAG loaded!")
+
+    # def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
+    #     """
+    #     Rerank the contexts based on their relevance to the given instruction.
+    #     """
+
+    #     rerank_model = self.rerank_model
+
+
+    #     tokenizer = AutoTokenizer.from_pretrained(rerank_model)
+    #     model = AutoModelForSequenceClassification.from_pretrained(rerank_model)
+
+    #     def get_score(query, passage):
+    #         """Calculate the relevance score of a passage with respect to a query."""
+
+
+    #         inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
+
+
+    #         with torch.no_grad():
+    #             outputs = model(**inputs)
+
+
+    #         logits = outputs.logits
+
+
+    #         score = logits.view(-1, ).float()
+
+
+    #         return score
+
+    #     scores = [get_score(instruction, c[0].page_content) for c in contexts]
+    #     combined = list(zip(contexts, scores))
+    #     sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
+    #     sorted_texts, _ = zip(*sorted_combined)
+
+    #     return sorted_texts[:number_of_contexts]
+
+    def get_context(self, instruction, number_of_contexts=2):
+        """Retrieve the most relevant contexts for a given instruction."""
+        documentos = self.vector_store.similarity_search_with_score(instruction, k=number_of_contexts)
+
+        # documentos = self.rerank_contexts(instruction, documentos, number_of_contexts=number_of_contexts)
+
+        # reranking is disabled; return the raw similarity-search results
+        return documentos
+
+    def predict_dolly(self, instruction, context, model_parameters):
+
+        api_key = os.getenv("HF_TOKEN")
+
+
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+
+        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "
+        # prompt = "You are a helpful assistant. Answer the question using only the context you are provided with. If it is not possible to do it with the context, just say 'I can't answer'. <|endoftext|>"
+
+
+        payload = {
+            "inputs": query,
+            "parameters": model_parameters
+        }
+
+        response = requests.post(self.model_name, headers=headers, json=payload)  # requires self.model_name (currently commented out in __init__)
+
+        return response.json()[0]["generated_text"].split("###")[-1][8:]
+
+    def predict_completion(self, instruction, context, model_parameters):
+
+        client = OpenAI(
+            base_url=os.getenv("MODEL"),
+            api_key=os.getenv("HF_TOKEN")
+        )
+
+        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+
+        chat_completion = client.chat.completions.create(
+            model="tgi",
+            messages=[
+                {"role": "user", "content": query}  # send the context-augmented prompt, not just the instruction
+            ],
+            temperature=model_parameters["temperature"],
+            max_tokens=model_parameters["max_new_tokens"],
+            stream=False,
+            stop=["<|im_end|>"],
+            extra_body={
+                "presence_penalty": model_parameters["repetition_penalty"] - 2,
+                "do_sample": False
+            }
+        )
+
+        response = chat_completion.choices[0].message.content
+
+        return response
+
+
+    def beautiful_context(self, docs):
+
+        text_context = ""
+
+        full_context = ""
+        source_context = []
+        for doc in docs:
+            text_context += doc[0].page_content
+            full_context += doc[0].page_content + "\n"
+            full_context += doc[0].metadata["Títol de la norma"] + "\n\n"
+            full_context += doc[0].metadata["url"] + "\n\n"
+            source_context.append(doc[0].metadata["url"])
+
+        return text_context, full_context, source_context
+
+    def get_response(self, prompt: str, model_parameters: dict):
+        try:
+            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
+            text_context, full_context, source = self.beautiful_context(docs)
+
+            del model_parameters["NUM_CHUNKS"]
+
+            # response = self.predict_completion(prompt, text_context, model_parameters)
+            response = "Output"  # placeholder while the completion call above is disabled
+
+            if not response:
+                return self.NO_ANSWER_MESSAGE, full_context, source
+
+            return response, full_context, source
+        except Exception as err:
+            logging.error(err)
+            raise  # let callers (see generate() in app.py) handle the failure
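
Putting the pieces together: RAG only needs an embeddings model id and a Hub repo holding the serialized FAISS index; retrieval runs locally on CPU while generation is delegated to a remote endpoint (currently stubbed out). A minimal wiring sketch that mirrors app.py, assuming the same environment variables are set:

import os
from rag import RAG

rag = RAG(
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"),  # e.g. a sentence-transformers model id
    repo_name=os.getenv("REPO_NAME"),          # Hub repo containing the FAISS index
)

# NUM_CHUNKS is consumed by get_response; the rest would go to the generation endpoint.
answer, full_context, sources = rag.get_response(
    "Què és l'EADOP?",
    {"NUM_CHUNKS": 2, "max_new_tokens": 700, "repetition_penalty": 1.0,
     "top_k": 50, "top_p": 0.99, "do_sample": False, "temperature": 0.35},
)
print(answer)
print(sources)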
rag_image.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,14 @@
+gradio==4.29.0
+huggingface-hub==0.23.4
+python-dotenv==1.0.0
+llama-index==0.10.14
+llama-index-embeddings-huggingface==0.2.2
+llama-index-llms-huggingface==0.2.4
+sentence-transformers==2.7.0
+langchain
+faiss-cpu
+aina-gradio-theme==2.3
+
+langchain-community==0.2.1
+langchain-core==0.2.1
+openai==1.35.12
utils.py ADDED
@@ -0,0 +1,33 @@
+import logging
+import warnings
+
+from dotenv import load_dotenv
+
+
+from rag import RAG
+
+USER_INPUT = 100
+
+
+def setup():
+    load_dotenv()
+    warnings.filterwarnings("ignore")
+
+    logging.addLevelName(USER_INPUT, "USER_INPUT")
+    logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO)
+
+
+def interactive(model: RAG):
+    logging.info("Write `exit` when you want to stop the model.")
+    print()
+
+    query = ""
+    while query.lower() != "exit":
+        logging.log(USER_INPUT, "Write the query or `exit`:")
+        query = input()
+
+        if query.lower() == "exit":
+            break
+
+        response, _, _ = model.get_response(query, {"NUM_CHUNKS": 2})  # get_response expects a parameters dict
+        print(response, end="\n\n")
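
setup() loads the .env file and registers the custom USER_INPUT log level before anything else touches the environment, which is why app.py calls it before constructing RAG. For completeness, a hypothetical console driver for the interactive loop; the environment variable values are placeholders:

import os
from rag import RAG
from utils import setup, interactive

setup()  # load .env and configure logging first
rag = RAG(
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"),
    repo_name=os.getenv("REPO_NAME"),
)
interactive(rag)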