Marina Pliusnina committed on
Commit
2217335
1 Parent(s): 2b8b263
Files changed (9)
  1. README.md +4 -4
  2. app.py +128 -0
  3. gitignore +4 -0
  4. handler.py +14 -0
  5. input_reader.py +22 -0
  6. rag.py +73 -0
  7. rag_image.jpg +0 -0
  8. requirements.txt +8 -0
  9. utils.py +33 -0
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
-title: EADOP RAG
-emoji: 📈
-colorFrom: green
+title: Rag
+emoji: 💻
+colorFrom: indigo
 colorTo: yellow
 sdk: gradio
-sdk_version: 4.26.0
+sdk_version: 4.14.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py ADDED
@@ -0,0 +1,128 @@
import os

import gradio as gr
from gradio.components import Textbox, Button
# from AinaTheme import theme
from urllib.error import HTTPError

from rag import RAG
from utils import setup

setup()

rag = RAG(
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"),
    model_name=os.getenv("MODEL"),
)


def generate(prompt):
    try:
        return rag.get_response(prompt)
    except HTTPError as err:
        if err.code == 400:
            gr.Warning(
                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
            )
    except Exception:
        gr.Warning(
            "The inference endpoint is not available right now. Please try again later."
        )


def submit_input(input_):
    if input_.strip() == "":
        gr.Warning("It is not possible to run inference on an empty input.")
        return None
    return generate(input_)


def change_interactive(text):
    # "Clear" stays enabled; "Submit" is enabled only when there is input text.
    if len(text) == 0:
        return gr.update(interactive=True), gr.update(interactive=False)
    return gr.update(interactive=True), gr.update(interactive=True)


def clear():
    return None, None


def gradio_app():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0.1):
                gr.Image(
                    "rag_image.jpg",
                    elem_id="flor-banner",
                    scale=1,
                    height=256,
                    width=256,
                    show_label=False,
                    show_download_button=False,
                    show_share_button=False,
                )
            with gr.Column():
                gr.Markdown(
                    """# Retrieval-Augmented Generation (experimental)

🔍 **Retrieval-Augmented Generation** (RAG) is an AI framework for improving the quality of LLM-generated responses by grounding the model on external sources of knowledge that supplement the LLM's internal representation of information. Implementing RAG in an LLM-based question-answering system has two main benefits: it ensures that the model has access to the most current, reliable facts, and it gives users access to the model's sources, so the information can be checked for accuracy and ultimately trusted.

🎯 **Purpose:** The main purpose of this RAG is to answer questions about the [AI Act](https://artificialintelligenceact.eu/wp-content/uploads/2024/01/AI-Act-FullText.pdf). By incorporating external knowledge sources, RAG enables the LLM to provide more informed and reliable responses specifically tailored to inquiries about it.

⚠️ **Limitations:** This version is for beta testing only. The content generated by these models is unsupervised and might be wrong. Please bear this in mind when exploring this resource.
"""
                )
        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                input_ = Textbox(
                    lines=11,
                    label="Input",
                    placeholder="e.g. What is the AI Act?",
                    # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
                )

            with gr.Column(variant="panel"):
                output = Textbox(
                    lines=11, label="Output", interactive=False, show_copy_button=True
                )
        with gr.Row(variant="panel"):
            clear_btn = Button("Clear")
            submit_btn = Button("Submit", variant="primary", interactive=False)

        input_.change(
            fn=change_interactive,
            inputs=[input_],
            outputs=[clear_btn, submit_btn],
            api_name=False,
        )

        # Client-side character counter. No 'inputlenght' element is defined in
        # this file and the max-length argument `m` is never supplied, so the
        # callback is guarded to avoid throwing when the element is absent.
        input_.change(
            fn=None,
            inputs=[input_],
            api_name=False,
            js="""(i, m) => {
                const el = document.getElementById('inputlenght');
                if (!el) return;
                el.textContent = i.length + ' ';
                el.style.color = (i.length > m) ? "#ef4444" : "";
            }""",
        )

        clear_btn.click(
            fn=clear, inputs=[], outputs=[input_, output], queue=False, api_name=False
        )
        submit_btn.click(
            fn=submit_input, inputs=[input_], outputs=[output], api_name="get-results"
        )

    demo.launch(show_api=True)


if __name__ == "__main__":
    gradio_app()
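The app expects three environment variables, loaded by utils.setup() through python-dotenv. A minimal sketch of the .env file it reads; the values shown are placeholders, not part of this commit:

HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
EMBEDDINGS=sentence-transformers/all-MiniLM-L6-v2
MODEL=https://your-inference-endpoint.example.com

Note that MODEL is a full endpoint URL rather than a model id: rag.py POSTs directly to it. The EMBEDDINGS value above is a hypothetical sentence-transformers model; any model accepted by HuggingFaceEmbeddings would work.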
gitignore ADDED
@@ -0,0 +1,4 @@
venv
**/__pycache__
.env
vectorestore/
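The ignored vectorestore/ directory is the local FAISS index that rag.py loads at startup; the spelling matches the string passed to FAISS.load_local.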
handler.py ADDED
@@ -0,0 +1,14 @@
import json


class ContentHandler:
    """Serializes prompts to, and parses generations from, a JSON inference endpoint."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # Wrap the prompt and generation parameters in the endpoint's JSON schema.
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # The endpoint returns a JSON list; the first element holds the generation.
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]
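Nothing in this commit imports ContentHandler; it follows the content-handler shape used by LangChain's hosted-endpoint integrations. A quick round trip under the JSON schema shown above (max_new_tokens is an arbitrary example parameter):

import io

from handler import ContentHandler

handler = ContentHandler()

# Serialize a prompt plus generation parameters into a request body.
body = handler.transform_input("What is the AI Act?", {"max_new_tokens": 128})
# b'{"inputs": "What is the AI Act?", "parameters": {"max_new_tokens": 128}}'

# transform_output() calls .read(), so it expects a stream, not raw bytes.
stream = io.BytesIO(b'[{"generated_text": "The AI Act is a European regulation."}]')
print(handler.transform_output(stream))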
input_reader.py ADDED
@@ -0,0 +1,22 @@
from typing import List

from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
from llama_index.core.readers import SimpleDirectoryReader
from llama_index.core.schema import Document
from llama_index.core import Settings


class InputReader:
    def __init__(self, input_dir: str) -> None:
        self.reader = SimpleDirectoryReader(input_dir=input_dir)

    def parse_documents(
        self,
        show_progress: bool = True,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    ) -> List[Document]:
        # Chunking is configured through llama-index's global Settings object.
        Settings.chunk_size = chunk_size
        Settings.chunk_overlap = chunk_overlap
        documents = self.reader.load_data(show_progress=show_progress)
        return documents
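A minimal usage sketch, assuming a local directory of source documents (the data/ path is hypothetical). Note that the chunk settings go into llama-index's global Settings, so they affect later node parsing rather than load_data() itself, which returns one Document per file or PDF page:

from input_reader import InputReader

reader = InputReader(input_dir="data")  # hypothetical directory of PDFs/text files
documents = reader.parse_documents(chunk_size=512, chunk_overlap=64)
print(f"Loaded {len(documents)} documents")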
rag.py ADDED
@@ -0,0 +1,73 @@
import logging

import requests

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings


class RAG:
    NO_ANSWER_MESSAGE: str = "Sorry, I couldn't answer your question."

    def __init__(self, hf_token, embeddings_model, model_name):
        # model_name is the URL of the hosted inference endpoint queried in predict().
        self.model_name = model_name
        self.hf_token = hf_token

        # Load the vector store from the local "vectorestore" directory.
        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={"device": "cpu"})
        self.vectore_store = FAISS.load_local("vectorestore", embeddings, allow_dangerous_deserialization=True)

        logging.info("RAG loaded!")

    def get_context(self, instruction, number_of_contexts=1):
        context = ""
        documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)
        for doc in documentos:
            # Each result is a (Document, score) tuple; keep only the text.
            context += doc[0].page_content
        return context

    def predict(self, instruction, context):
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.hf_token}",
            "Content-Type": "application/json",
        }

        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "

        payload = {
            "inputs": query,
            "parameters": {},
        }

        response = requests.post(self.model_name, headers=headers, json=payload)

        # The endpoint echoes the prompt, so keep only the text after the last
        # "###" marker and slice off the "Answer" label and trailing character.
        return response.json()[0]["generated_text"].split("###")[-1][8:-1]

    def get_response(self, prompt: str) -> str:
        context = self.get_context(prompt)
        response = self.predict(prompt, context)

        if not response:
            return self.NO_ANSWER_MESSAGE

        return response
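The vectorestore/ index that load_local() expects is not built anywhere in this commit. A sketch of how it could be produced with the same embeddings model and this repo's InputReader; the data/ directory is an assumption, and the llama-index documents are converted to LangChain documents before indexing:

import os

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

from input_reader import InputReader

# Read the source documents (hypothetical data/ directory).
docs = InputReader(input_dir="data").parse_documents()

# Convert llama-index documents into LangChain documents for FAISS.
texts = [Document(page_content=doc.text) for doc in docs]

embeddings = HuggingFaceEmbeddings(
    model_name=os.getenv("EMBEDDINGS"), model_kwargs={"device": "cpu"}
)

# Build the index and persist it under the directory name rag.py loads.
FAISS.from_documents(texts, embeddings).save_local("vectorestore")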
rag_image.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
gradio==4.14.0
python-dotenv==1.0.0
llama-index==0.10.14
llama-index-embeddings-huggingface==0.1.4
llama-index-llms-huggingface==0.1.3
sentence-transformers
langchain
faiss-cpu
utils.py ADDED
@@ -0,0 +1,33 @@
import logging
import warnings

from dotenv import load_dotenv

from rag import RAG

# Custom logging level used to prompt for user input in interactive().
USER_INPUT = 100


def setup():
    load_dotenv()
    warnings.filterwarnings("ignore")

    logging.addLevelName(USER_INPUT, "USER_INPUT")
    logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO)


def interactive(model: RAG):
    logging.info("Write `exit` when you want to stop the model.")
    print()

    query = ""
    while query.lower() != "exit":
        logging.log(USER_INPUT, "Write the query or `exit`:")
        query = input()

        if query.lower() == "exit":
            break

        response = model.get_response(query)
        print(response, end="\n\n")
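interactive() is never called in this commit; a minimal sketch (not part of the repo) of a CLI entry point that reuses the same environment variables as app.py:

import os

from rag import RAG
from utils import interactive, setup

if __name__ == "__main__":
    setup()
    model = RAG(
        hf_token=os.getenv("HF_TOKEN"),
        embeddings_model=os.getenv("EMBEDDINGS"),
        model_name=os.getenv("MODEL"),
    )
    interactive(model)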