Quentin Fisch committed
Commit efb5688
1 Parent(s): 5c4f525

feat(demo): add demo files

Files changed (3)
  1. app.py +79 -0
  2. confluence_rag.py +185 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,79 @@
+ """
+ Gradio UI for Mistral 7B with RAG
+ """
+
+ import os
+ from typing import List
+
+ import gradio as gr
+ from langchain_core.runnables.base import RunnableSequence
+ import numpy as np
+ from confluence_rag import generate_rag_chain, load_pdf, store_vector, load_multiple_pdf
+
+
+ def initialize_chain(file: gr.File) -> RunnableSequence:
+     """
+     Initializes the chain with the given file.
+
+     If no file is provided, the LLM is used without RAG.
+
+     Args:
+         file (gr.File): file to initialize the chain with
+
+     Returns:
+         RunnableSequence: the chain
+     """
+     if file is None:
+         return generate_rag_chain()
+
+     if len(file) == 1:
+         pdf = load_pdf(file[0].name)
+     else:
+         pdf = load_multiple_pdf([f.name for f in file])
+     retriever = store_vector(pdf)
+
+     return generate_rag_chain(retriever)
+
+
+ def invoke_chain(message: str, history: List[str], file: gr.File = None) -> str:
+     """
+     Invokes the chain with the given message and updates the chain if a new file is provided.
+
+     Args:
+         message (str): message to invoke the chain with
+         history (List[str]): history of messages
+         file (gr.File, optional): file to update the chain with. Defaults to None.
+
+     Returns:
+         str: the response of the chain
+     """
+     # Check that every provided file exists (guard the None case before calling len())
+     if file is not None and (len(file) == 0 or not np.all([os.path.exists(f.name) for f in file])):
+         return "Error: File not found."
+
+     if file is not None and not np.all([f.name.endswith(".pdf") for f in file]):
+         return "Error: File is not a PDF."
+
+     chain = initialize_chain(file)
+     return chain.invoke(message)
+
+
+ def create_demo() -> gr.Interface:
+     """
+     Creates and returns a Gradio Chat Interface.
+
+     Returns:
+         gr.Interface: the Gradio Chat Interface
+     """
+     return gr.ChatInterface(
+         invoke_chain,
+         additional_inputs=[gr.File(label="File", file_count="multiple")],
+         title="Mistral 7B with RAG",
+         description="Ask questions to Mistral about your PDF document.",
+         theme="soft",
+     )
+
+
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch()
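
For a quick sanity check of the wiring above, the chat callback can be driven without the UI. A minimal sketch (illustrative, not part of this commit), assuming HF_API_KEY is set and both modules are importable:

# Illustrative smoke test for app.py's chain wiring (not part of the commit).
from app import invoke_chain

# With file=None, initialize_chain() falls back to generate_rag_chain() without
# a retriever, so this exercises the plain Mistral prompt path (no PDF needed).
print(invoke_chain("Summarize what RAG is in one sentence.", history=[]))
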
confluence_rag.py ADDED
@@ -0,0 +1,185 @@
+ import os
+ from typing import List
+
+ from langchain_community.document_loaders import UnstructuredPDFLoader
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+ from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.schema.output_parser import StrOutputParser
+ from langchain.schema.runnable import RunnablePassthrough
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores.chroma import Chroma
+ from langchain_core.runnables.base import RunnableSequence
+ from langchain_core.vectorstores import VectorStoreRetriever
+
+ from dotenv import load_dotenv
+
+
+ load_dotenv()
+ HF_API_KEY = os.environ["HF_API_KEY"]
+
+
+ class MistralOutputParser(StrOutputParser):
+     """OutputParser that parses LLM results from the Mistral API."""
+
+     def parse(self, text: str) -> str:
+         """
+         Extracts the model answer that follows the last [/INST] tag.
+
+         Args:
+             text (str): text to parse
+
+         Returns:
+             str: parsed text
+         """
+         return text.split("[/INST]")[-1].strip()
+
+
+ def load_pdf(
+     document_path: str,
+     mode: str = "single",
+     strategy: str = "fast",
+     chunk_size: int = 500,
+     chunk_overlap: int = 0,
+ ) -> List[str]:
+     """
+     Load a PDF document and split it into chunks of text.
+
+     Args:
+         document_path (str): path to the PDF document
+         mode (str, optional): mode of the loader. Defaults to "single".
+         strategy (str, optional): strategy of the loader. Defaults to "fast".
+         chunk_size (int, optional): size of the chunks. Defaults to 500.
+         chunk_overlap (int, optional): overlap of the chunks. Defaults to 0.
+
+     Returns:
+         List[str]: list of chunks of text
+     """
+
+     # Load the document
+     loader = UnstructuredPDFLoader(
+         document_path,
+         mode=mode,
+         strategy=strategy,
+     )
+
+     docs = loader.load()
+
+     # Split the document into chunks of text
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+     all_splits = text_splitter.split_documents(docs)
+
+     return all_splits
+
+
+ def store_vector(all_splits: List[str]) -> VectorStoreRetriever:
+     """
+     Store the embedding of each chunk of text.
+
+     Args:
+         all_splits (List[str]): list of chunks of text
+
+     Returns:
+         VectorStoreRetriever: retriever over the embedded chunks
+     """
+
+     # Use the HuggingFace distilbert-base-uncased model to embed the text
+     embeddings_model_url = (
+         "https://api-inference.huggingface.co/models/distilbert-base-uncased"
+     )
+
+     embeddings = HuggingFaceInferenceAPIEmbeddings(
+         endpoint_url=embeddings_model_url,
+         api_key=HF_API_KEY,
+     )
+
+     # Store the embeddings of each chunk of text in ChromaDB
+     vector_store = Chroma.from_documents(all_splits, embeddings)
+     retriever = vector_store.as_retriever()
+
+     return retriever
+
+
+ def generate_mistral_rag_prompt() -> ChatPromptTemplate:
+     """
+     Generate a prompt for the Mistral API with RAG.
+
+     Returns:
+         ChatPromptTemplate: prompt for the Mistral API
+     """
+     template = "<s>[INST] {context} {prompt} [/INST]"
+     prompt_template = ChatPromptTemplate.from_template(template)
+     return prompt_template
+
+
+ def generate_mistral_simple_prompt() -> ChatPromptTemplate:
+     """
+     Generate a simple prompt for Mistral without RAG.
+
+     Returns:
+         ChatPromptTemplate: prompt for the Mistral API
+     """
+     template = "[INST] {prompt} [/INST]"
+     prompt_template = ChatPromptTemplate.from_template(template)
+     return prompt_template
+
+
+ def generate_rag_chain(retriever: VectorStoreRetriever = None) -> RunnableSequence:
+     """
+     Generate a RAG chain with the Mistral API and ChromaDB.
+
+     Args:
+         retriever (VectorStoreRetriever, optional): retriever used to fetch relevant chunks. Defaults to None.
+
+     Returns:
+         RunnableSequence: RAG chain
+     """
+     # Use the Mistral free prototype API
+     mistral_url = (
+         "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
+     )
+
+     model_endpoint = HuggingFaceEndpoint(
+         endpoint_url=mistral_url,
+         huggingfacehub_api_token=HF_API_KEY,
+         task="text2text-generation",
+     )
+
+     # Use a custom output parser
+     output_parser = MistralOutputParser()
+
+     # If no retriever is provided, use a simple prompt
+     if retriever is None:
+         entry = {"prompt": RunnablePassthrough()}
+         return entry | generate_mistral_simple_prompt() | model_endpoint | output_parser
+
+     # If a retriever is provided, use a RAG prompt
+     retrieval = {"context": retriever, "prompt": RunnablePassthrough()}
+
+     return retrieval | generate_mistral_rag_prompt() | model_endpoint | output_parser
+
+
+ def load_multiple_pdf(document_paths: List[str]) -> List[str]:
+     """
+     Load multiple PDF documents and split them into chunks of text.
+
+     Args:
+         document_paths (List[str]): list of paths to the PDF documents
+
+     Returns:
+         List[str]: list of chunks of text
+     """
+     docs = []
+     for document_path in document_paths:
+         loader = UnstructuredPDFLoader(
+             document_path,
+             mode="single",
+             strategy="fast",
+         )
+         docs.extend(loader.load())
+
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=25)
+     all_splits = text_splitter.split_documents(docs)
+     return all_splits
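
Taken together, confluence_rag.py composes a standard LCEL pipeline: PDF chunks are embedded into an in-memory Chroma store, and at query time the retriever fills {context} while the user message passes through as {prompt}. A minimal end-to-end sketch (illustrative only; the PDF path is a placeholder and HF_API_KEY must be present in the environment or a .env file):

# Hypothetical end-to-end usage of confluence_rag.py (path is a placeholder).
from confluence_rag import generate_rag_chain, load_pdf, store_vector

splits = load_pdf("manual.pdf")        # load and chunk one PDF
retriever = store_vector(splits)       # embed chunks into an in-memory Chroma store
chain = generate_rag_chain(retriever)  # {context, prompt} -> RAG prompt -> Mistral -> parser
print(chain.invoke("What does the document say about deployment?"))
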
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ langchain==0.1.9
+ chromadb==0.4.24
+ unstructured[pdf]
+ gradio
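
To run the demo locally: install the pinned dependencies with pip install -r requirements.txt, put HF_API_KEY=<your token> in a .env file next to the modules, then launch with python app.py. Note that python-dotenv and numpy are imported but not pinned here; they usually arrive transitively (e.g., via chromadb), but may need a separate install otherwise. A small pre-flight check, assuming the .env layout above:

# Pre-flight check (illustrative): confirm the token is visible before launching.
import os
from dotenv import load_dotenv

load_dotenv()
assert "HF_API_KEY" in os.environ, "Put HF_API_KEY=<your token> in a .env file"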