vteam27 commited on
Commit
5a08d5f
1 Parent(s): 57b1a45
Files changed (3) hide show
  1. app.py +147 -0
  2. requirements.txt +21 -5
  3. utils.py +59 -1
app.py CHANGED
@@ -14,6 +14,7 @@ from happytransformer import HappyTextToText, TTSettings
14
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,logging
15
  from transformers.integrations import deepspeed
16
  import re
 
17
  from lang_list import (
18
  LANGUAGE_NAME_TO_CODE,
19
  T2TT_TARGET_LANGUAGE_NAMES,
@@ -251,12 +252,158 @@ with gr.Blocks() as demo_t2tt:
251
  api_name="t2tt",
252
  )
253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
  with gr.Blocks() as demo:
255
  with gr.Tabs():
256
  with gr.Tab(label="OCR"):
257
  demo_ocr.render()
258
  with gr.Tab(label="Translate"):
259
  demo_t2tt.render()
 
 
260
 
261
  if __name__ == "__main__":
262
  demo.launch()
 
14
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,logging
15
  from transformers.integrations import deepspeed
16
  import re
17
+ import torch
18
  from lang_list import (
19
  LANGUAGE_NAME_TO_CODE,
20
  T2TT_TARGET_LANGUAGE_NAMES,
 
252
  api_name="t2tt",
253
  )
254
 
255
+
256
+ #RAG
257
+ import utils
258
+ from langchain_mistralai import ChatMistralAI
259
+ from langchain_core.prompts import ChatPromptTemplate
260
+ from langchain_core.output_parsers import StrOutputParser
261
+ from langchain_community.vectorstores import Chroma
262
+ from langchain_huggingface import HuggingFaceEmbeddings
263
+ from langchain_core.runnables import RunnablePassthrough
264
+ os.environ['MISTRAL_API_KEY'] = 'XuyOObDE7trMbpAeI7OXYr3dnmoWy3L0'
265
+
266
+ class VectorData():
267
+ def __init__(self):
268
+ embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'
269
+
270
+ model_kwargs = {'device':'cuda' if torch.cuda.is_available() else 'cpu',"trust_remote_code": True}
271
+
272
+ self.embeddings = HuggingFaceEmbeddings(
273
+ model_name=embedding_model_name,
274
+ model_kwargs=model_kwargs
275
+ )
276
+
277
+ self.vectorstore = Chroma(persist_directory="chroma_db", embedding_function=self.embeddings)
278
+ self.retriever = self.vectorstore.as_retriever()
279
+ self.ingested_files = []
280
+ self.prompt = ChatPromptTemplate.from_messages(
281
+ [
282
+ (
283
+ "system",
284
+ """Answer the question based on the given context. Dont give any ans if context is not valid to question. Always give the source of context:
285
+ {context}
286
+ """,
287
+ ),
288
+ ("human", "{question}"),
289
+ ]
290
+ )
291
+ self.llm = ChatMistralAI(model="mistral-large-latest")
292
+ self.rag_chain = (
293
+ {"context": self.retriever, "question": RunnablePassthrough()}
294
+ | self.prompt
295
+ | self.llm
296
+ | StrOutputParser()
297
+ )
298
+
299
+ def add_file(self,file):
300
+ if file is not None:
301
+ self.ingested_files.append(file.name.split('/')[-1])
302
+ self.retriever, self.vectorstore = utils.add_doc(file,self.vectorstore)
303
+ self.rag_chain = (
304
+ {"context": self.retriever, "question": RunnablePassthrough()}
305
+ | self.prompt
306
+ | self.llm
307
+ | StrOutputParser()
308
+ )
309
+ return [[name] for name in self.ingested_files]
310
+
311
+ def delete_file_by_name(self,file_name):
312
+ if file_name in self.ingested_files:
313
+ self.retriever, self.vectorstore = utils.delete_doc(file_name,self.vectorstore)
314
+ self.ingested_files.remove(file_name)
315
+ return [[name] for name in self.ingested_files]
316
+
317
+ def delete_all_files(self):
318
+ self.ingested_files.clear()
319
+ self.retriever, self.vectorstore = utils.delete_all_doc(self.vectorstore)
320
+ return []
321
+
322
+ data_obj = VectorData()
323
+
324
+ # Function to handle question answering
325
+ def answer_question(question):
326
+ if question.strip():
327
+ return f'{data_obj.rag_chain.invoke(question)}'
328
+ return "Please enter a question."
329
+
330
+ with gr.Blocks() as rag_interface:
331
+ # Title and Description
332
+ gr.Markdown("# RAG Interface")
333
+ gr.Markdown("Manage documents and ask questions with a Retrieval-Augmented Generation (RAG) system.")
334
+
335
+ with gr.Row():
336
+ # Left Column: File Management
337
+ with gr.Column():
338
+ gr.Markdown("### File Management")
339
+
340
+ # File upload and ingest
341
+ file_input = gr.File(label="Upload File to Ingest")
342
+ add_file_button = gr.Button("Ingest File")
343
+
344
+ # Scrollable list for ingested files
345
+ ingested_files_box = gr.Dataframe(
346
+ headers=["Files"],
347
+ datatype="str",
348
+ row_count=4, # Limits the visible rows to create a scrollable view
349
+ interactive=False
350
+ )
351
+
352
+ # Radio buttons to choose delete option
353
+ delete_option = gr.Radio(choices=["Delete by File Name", "Delete All Files"], label="Delete Option")
354
+ file_name_input = gr.Textbox(label="Enter File Name to Delete", visible=False)
355
+ delete_button = gr.Button("Delete Selected")
356
+
357
+ # Show or hide file name input based on delete option selection
358
+ def toggle_file_input(option):
359
+ return gr.update(visible=(option == "Delete by File Name"))
360
+
361
+ delete_option.change(fn=toggle_file_input, inputs=delete_option, outputs=file_name_input)
362
+
363
+ # Handle file ingestion
364
+ add_file_button.click(
365
+ fn=data_obj.add_file,
366
+ inputs=file_input,
367
+ outputs=ingested_files_box
368
+ )
369
+
370
+ # Handle delete based on selected option
371
+ def delete_action(delete_option, file_name):
372
+ if delete_option == "Delete by File Name" and file_name:
373
+ return data_obj.delete_file_by_name(file_name)
374
+ elif delete_option == "Delete All Files":
375
+ return data_obj.delete_all_files()
376
+ else:
377
+ return [[name] for name in data_obj.ingested_files]
378
+
379
+ delete_button.click(
380
+ fn=delete_action,
381
+ inputs=[delete_option, file_name_input],
382
+ outputs=ingested_files_box
383
+ )
384
+
385
+ # Right Column: Question Answering
386
+ with gr.Column():
387
+ gr.Markdown("### Ask a Question")
388
+
389
+ # Question input
390
+ question_input = gr.Textbox(label="Enter your question")
391
+
392
+ # Get answer button and answer output
393
+ ask_button = gr.Button("Get Answer")
394
+ answer_output = gr.Textbox(label="Answer", interactive=False)
395
+
396
+ ask_button.click(fn=answer_question, inputs=question_input, outputs=answer_output)
397
+
398
+
399
  with gr.Blocks() as demo:
400
  with gr.Tabs():
401
  with gr.Tab(label="OCR"):
402
  demo_ocr.render()
403
  with gr.Tab(label="Translate"):
404
  demo_t2tt.render()
405
+ with gr.Tab(label="RAG"):
406
+ rag_interface.render()
407
 
408
  if __name__ == "__main__":
409
  demo.launch()
requirements.txt CHANGED
@@ -4,14 +4,30 @@ reportlab>=3.6.2
4
  PyPDF2==1.26.0
5
  happytransformer
6
  python-doctr[torch]@git+https://github.com/mindee/doctr.git
7
- transformers
8
  fairseq2==0.1
9
- pydub
10
  yt-dlp
11
  sentencepiece
12
  nltk
13
- numpy==1.26.4
14
  opencv-python==4.9.0.80
15
- packaging
16
  pillow==10.3.0
17
- pytesseract==0.3.10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  PyPDF2==1.26.0
5
  happytransformer
6
  python-doctr[torch]@git+https://github.com/mindee/doctr.git
 
7
  fairseq2==0.1
 
8
  yt-dlp
9
  sentencepiece
10
  nltk
 
11
  opencv-python==4.9.0.80
 
12
  pillow==10.3.0
13
+ pytesseract==0.3.10
14
+ packaging
15
+ torch
16
+ fastapi
17
+ uvicorn
18
+ pandas
19
+ numpy
20
+ torch
21
+ transformers
22
+ scikit-learn
23
+ sentence-transformers
24
+ langchain
25
+ langchain-community
26
+ langchain-core
27
+ langchain-huggingface
28
+ langchain-mistralai
29
+ langchain-text-splitters
30
+ langsmith
31
+ chroma-hnswlib
32
+ chromadb
33
+ fastapi
utils.py CHANGED
@@ -160,4 +160,62 @@ class HocrParser():
160
  if image is not None:
161
  pdf.drawImage(ImageReader(Image.fromarray(image)),
162
  0, 0, width=width, height=height)
163
- pdf.save()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  if image is not None:
161
  pdf.drawImage(ImageReader(Image.fromarray(image)),
162
  0, 0, width=width, height=height)
163
+ pdf.save()
164
+
165
+
166
+
167
+ from langchain_huggingface import HuggingFaceEmbeddings
168
+ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
169
+ from langchain_community.vectorstores import Chroma
170
+ from langchain.schema import Document
171
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
172
+ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
173
+ import torch
174
+
175
+ embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'
176
+
177
+ model_kwargs = {'device':'cuda' if torch.cuda.is_available() else 'cpu',"trust_remote_code": True}
178
+
179
+ embeddings = HuggingFaceEmbeddings(
180
+ model_name=embedding_model_name,
181
+ model_kwargs=model_kwargs
182
+ )
183
+
184
+ vectorstore = None
185
+
186
+
187
+
188
+ def read_file(data: str) -> Document:
189
+ f = open(data,'r')
190
+ content = f.read()
191
+ f.close()
192
+ doc = Document(page_content=content, metadata={"name": data.split('/')[-1]})
193
+ return doc
194
+
195
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
196
+
197
+ def add_doc(data,vectorstore):
198
+ doc = read_file(data)
199
+ splits = text_splitter.split_documents([doc])
200
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
201
+ retriever = vectorstore.as_retriever(search_kwargs={'k':1})
202
+ return retriever, vectorstore
203
+
204
+ def delete_doc(delete_name,vectorstore):
205
+ delete_doc_ids = []
206
+ for idx,name in enumerate(vectorstore.get()['metadatas']):
207
+ if name['name'] == delete_name:
208
+ delete_doc_ids.append(vectorstore.get()['ids'][idx])
209
+ for id in delete_doc_ids:
210
+ vectorstore.delete(ids = id)
211
+ # vectorstore.persist()
212
+ retriever = vectorstore.as_retriever(search_kwargs={'k':1})
213
+ return retriever, vectorstore
214
+
215
+ def delete_all_doc(vectorstore):
216
+ delete_doc_ids = vectorstore.get()['ids']
217
+ for id in delete_doc_ids:
218
+ vectorstore.delete(ids = id)
219
+ # vectorstore.persist()
220
+ retriever = vectorstore.as_retriever(search_kwargs={'k':1})
221
+ return retriever, vectorstore