not-lain committed on
Commit
6cc6b0d
β€’
1 Parent(s): c3f7f41

first commit

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +108 -0
  3. requirements.txt +3 -0
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: RAG
3
- emoji: 🐒
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
 
1
  ---
2
  title: RAG
3
+ emoji: πŸŒ”WπŸŒ’
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import PyPDF2
3
+ import pandas as pd
4
+ import warnings
5
+ import re
6
+ from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
7
+ from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
8
+ import torch
9
+ import gradio as gr
10
+ from typing import Union
11
+ from datasets import Dataset
12
+ warnings.filterwarnings("ignore")
13
+
14
+
15
+
16
+ torch.set_grad_enabled(False)
17
+ ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
18
+ ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
19
+
20
+
21
+
22
+ q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
23
+ q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
24
+
25
+
26
def process_pdfs(parent_dir: Union[str, list]):
    """Extract the text of every page of one or more PDF files.

    Args:
        parent_dir: a single PDF file path, or a list of PDF file paths.

    Returns:
        pd.DataFrame with columns ``title`` and ``text`` — one row per page;
        pages of 512+ characters are split into 512-character chunks
        (one row each).  512 matches the positional-encoding limit of the
        "facebook/dpr-ctx_encoder-single-nq-base" model.
    """
    # Accumulate plain dicts and build the frame once at the end:
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0,
    # and row-wise append was quadratic anyway.
    rows = []
    if isinstance(parent_dir, str):
        parent_dir = [parent_dir]
    for file_path in parent_dir:
        # Fix: file_name was previously assigned only inside the short-page
        # branch, raising NameError when the first long page hit the
        # chunking branch.
        file_name = os.path.basename(file_path)
        # Context manager guarantees the handle is closed even on error.
        with open(file_path, "rb") as pdf_file:
            reader = PyPDF2.PdfReader(pdf_file)
            for i, page in enumerate(reader.pages):
                txt = page.extract_text()
                txt = txt.replace("\n", "")   # strip line breaks
                txt = txt.replace("\t", "")   # strip tabs
                txt = re.sub(r" +", " ", txt)  # collapse runs of spaces
                if len(txt) < 512:
                    rows.append({"title": f"{file_name}-page-{i}", "text": txt})
                else:
                    # Split long pages into 512-char chunks (titles now use
                    # the same "-page-" format as the short-page branch).
                    while len(txt) > 512:
                        rows.append({"title": f"{file_name}-page-{i}",
                                     "text": txt[:512]})
                        txt = txt[512:]
                    # Fix: the trailing <=512-char remainder used to be
                    # silently dropped; keep it so no page text is lost.
                    if txt:
                        rows.append({"title": f"{file_name}-page-{i}",
                                     "text": txt})
    return pd.DataFrame(rows, columns=["title", "text"])
62
+
63
def process(example):
    """Embed one dataset row's text with the DPR context encoder.

    Args:
        example: a dataset row (dict-like) with a "text" field.

    Returns:
        dict with key ``'embeddings'``: a 1-D numpy vector for the text.
    """
    # truncation=True guards against texts whose *token* count exceeds the
    # model's 512-position limit — the upstream 512-character cap only
    # approximates the token count and can overshoot it.
    tokens = ctx_tokenizer(example["text"], truncation=True, max_length=512,
                           return_tensors="pt")
    embed = ctx_encoder(**tokens)[0][0].numpy()
    return {'embeddings': embed}
68
+
69
def process_dataset(df):
    """Turn a dataframe of page texts into a FAISS-searchable dataset.

    Converts *df* to a Hugging Face ``Dataset``, computes a DPR embedding
    for every row via :func:`process`, and attaches a FAISS index over the
    ``embeddings`` column so nearest-neighbour queries are possible.
    """
    indexed = Dataset.from_pandas(df).map(process)
    indexed.add_faiss_index(column='embeddings')
    return indexed
75
+
76
def search(query, ds, k=3):
    """Format the *k* passages of *ds* most similar to *query*.

    Encodes the query with the DPR question encoder and runs a FAISS
    nearest-neighbour lookup over the dataset's ``embeddings`` index.
    Returns a human-readable string with the top hit and the titles of
    all *k* retrieved passages.
    """
    q_tokens = q_tokenizer(query, return_tensors="pt")
    query_vector = q_encoder(**q_tokens)[0][0].numpy()
    _scores, retrieved_examples = ds.get_nearest_examples(
        "embeddings", query_vector, k=k
    )
    out = f"""title : {retrieved_examples["title"][0]},\ncontent: {retrieved_examples["text"][0]}
    similar resources: {retrieved_examples["title"]}
    """
    return out
85
+
86
def predict(query,file_paths, k=3):
    """End-to-end RAG lookup: index the given PDFs, then search them.

    Args:
        query: natural-language question.
        file_paths: one PDF path or a list of PDF paths.
        k: number of neighbours to retrieve.  The Gradio ``Number`` widget
           delivers this as a float (e.g. 3.0), so it is coerced below.

    Returns:
        str: formatted search result (see ``search``).
    """
    # Fix: faiss / get_nearest_examples expects an integer neighbour count.
    k = int(k)
    df = process_pdfs(file_paths)
    ds = process_dataset(df)
    return search(query, ds, k=k)
91
+
92
# Gradio UI: upload one or more PDFs, type a question, and retrieve the
# k most relevant passages via DPR embeddings + FAISS (see predict()).
with gr.Blocks() as demo :
    with gr.Column():
        # Inputs: PDF files (as filesystem paths) plus the free-text query.
        files = gr.Files(label="Upload PDFs",type="filepath",file_count="multiple")
        query = gr.Text(label="query")
        with gr.Accordion():
            # Number of passages to retrieve.
            # NOTE(review): gr.Number yields a float, so predict() receives
            # k as e.g. 3.0 — confirm downstream faiss call accepts it.
            k = gr.Number(label="number of results",value=3)
        button = gr.Button("search")
    with gr.Column():
        # Output: the formatted search result string from predict().
        output = gr.Textbox(label="output")
    # Wire the button: (query, files, k) -> predict -> output textbox.
    button.click(predict, [query,files,k],outputs=output)

demo.launch()
104
+
105
+
106
+
107
+
108
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ datasets
2
+ PyPDF2
3
+ torch