rtabrizi commited on
Commit
edbedf3
1 Parent(s): a2ea59f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -29
app.py CHANGED
@@ -67,16 +67,16 @@ class Retriever:
67
  def load_chunks(self):
68
  self.text = self.extract_text_from_pdf(self.file_path)
69
  text_splitter = RecursiveCharacterTextSplitter(
70
- chunk_size=300,
71
  chunk_overlap=20,
72
  length_function=self.token_len,
73
- separators=["\n\n", " ", ".", ""]
74
  )
75
 
76
  self.chunks = text_splitter.split_text(self.text)
77
 
78
  def load_context_embeddings(self):
79
- encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=100).to(device)
80
 
81
  with torch.no_grad():
82
  model_output = self.context_model(**encoded_input)
@@ -89,20 +89,16 @@ class Retriever:
89
  encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
- model_output = self.question_model(**encoded_query)
93
- query_vector = model_output.pooler_output
94
 
95
  query_vector_np = query_vector.cpu().numpy()
96
  D, I = self.index.search(query_vector_np, k)
97
 
98
- retrieved_texts = [self.chunks[i] for i in I[0]]
99
 
100
  scores = [d for d in D[0]]
101
 
102
- # print("Top 5 retrieved texts and their associated scores:")
103
- # for idx, (text, score) in enumerate(zip(retrieved_texts, scores)):
104
- # print(f"{idx + 1}. Text: {text} \n Score: {score:.4f}\n")
105
-
106
  return retrieved_texts
107
 
108
  class RAG:
@@ -115,22 +111,23 @@ class RAG:
115
 
116
  # generator_name = "valhalla/bart-large-finetuned-squadv1"
117
  # generator_name = "'vblagoje/bart_lfqa'"
118
- generator_name = "a-ware/bart-squadv2"
119
-
120
  self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
121
  self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)
122
 
 
 
 
 
 
 
123
  self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
124
  self.retriever.load_chunks()
125
  self.retriever.load_context_embeddings()
126
 
127
- def get_answer(self, question, context):
128
- input_text = "context: %s <question for context: %s </s>" % (context,question)
129
- features = self.generator_tokenizer([input_text], return_tensors='pt')
130
- out = self.generator_model.generate(input_ids=features['input_ids'].to(device), attention_mask=features['attention_mask'].to(device))
131
- return self.generator_tokenizer.decode(out[0])
132
 
133
- def query(self, question):
134
  context = self.retriever.retrieve_top_k(question, k=5)
135
  # input_text = question + " " + " ".join(context)
136
 
@@ -144,22 +141,46 @@ class RAG:
144
  answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
145
  return answer
146
 
 
 
 
147
 
148
- context_model_name="facebook/dpr-ctx_encoder-single-nq-base"
149
- context_model_name="facebook/dpr-ctx_encoder-multiset-base"
150
- question_model_name="facebook/dpr-question_encoder-multiset-base"
151
 
152
- rag = RAG(file_path, device)
 
 
 
 
 
 
153
 
154
- query = "what is the benefit of using multiple attention heads in mult-head attention?"
 
155
 
156
- print(rag.query(query))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  st.title("RAG Model Query Interface")
159
 
160
- query = st.text_area("Enter your question:")
 
 
161
 
162
- # If a query is given, get the answer
163
- if query:
164
- answer = rag.query(query)
165
- st.write(answer)
 
67
  def load_chunks(self):
68
  self.text = self.extract_text_from_pdf(self.file_path)
69
  text_splitter = RecursiveCharacterTextSplitter(
70
+ chunk_size=150,
71
  chunk_overlap=20,
72
  length_function=self.token_len,
73
+ separators=["Section", "\n\n", "\n", ".", " ", ""]
74
  )
75
 
76
  self.chunks = text_splitter.split_text(self.text)
77
 
78
  def load_context_embeddings(self):
79
+ encoded_input = self.context_tokenizer(self.chunks, return_tensors='pt', padding=True, truncation=True, max_length=300).to(device)
80
 
81
  with torch.no_grad():
82
  model_output = self.context_model(**encoded_input)
 
89
  encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
+ model_output = self.question_model(**encoded_query)
93
+ query_vector = model_output.pooler_output
94
 
95
  query_vector_np = query_vector.cpu().numpy()
96
  D, I = self.index.search(query_vector_np, k)
97
 
98
+ retrieved_texts = [' '.join(self.chunks[i].split('\n')) for i in I[0]] # Replacing newlines with spaces
99
 
100
  scores = [d for d in D[0]]
101
 
 
 
 
 
102
  return retrieved_texts
103
 
104
  class RAG:
 
111
 
112
  # generator_name = "valhalla/bart-large-finetuned-squadv1"
113
  # generator_name = "'vblagoje/bart_lfqa'"
114
+ # generator_name = "a-ware/bart-squadv2"
115
+
116
  self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
117
  self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)
118
 
119
+ # generator_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
120
+ # generator_name = "t5-small"
121
+
122
+ # self.generator_tokenizer = T5Tokenizer.from_pretrained(generator_name)
123
+ # self.generator_model = T5ForConditionalGeneration.from_pretrained(generator_name)
124
+
125
  self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
126
  self.retriever.load_chunks()
127
  self.retriever.load_context_embeddings()
128
 
 
 
 
 
 
129
 
130
+ def abstractive_query(self, question):
131
  context = self.retriever.retrieve_top_k(question, k=5)
132
  # input_text = question + " " + " ".join(context)
133
 
 
141
  answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
142
  return answer
143
 
144
+ def extractive_query(self, question):
145
+ context = self.retriever.retrieve_top_k(question, k=15)
146
+ generator_name = "valhalla/bart-large-finetuned-squadv1"
147
 
148
+ self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
149
+ self.generator_model = BartForQuestionAnswering.from_pretrained(generator_name).to(device)
 
150
 
151
+ inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=200 , padding="max_length")
152
+ with torch.no_grad():
153
+ model_inputs = inputs.to(device)
154
+ outputs = self.generator_model(**model_inputs)
155
+
156
+ answer_start_index = outputs.start_logits.argmax()
157
+ answer_end_index = outputs.end_logits.argmax()
158
 
159
+ if answer_end_index < answer_start_index:
160
+ answer_start_index, answer_end_index = answer_end_index, answer_start_index
161
 
162
+ print(answer_start_index, answer_end_index)
163
+
164
+ predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
165
+ answer = self.generator_tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
166
+ answer = answer.replace('\n', ' ').strip()
167
+ answer = answer.replace('$', '')
168
+
169
+ return answer
170
+
171
+ context_model_name="facebook/dpr-ctx_encoder-single-nq-base"
172
+ question_model_name = "facebook/dpr-question_encoder-single-nq-base"
173
+ # context_model_name="facebook/dpr-ctx_encoder-multiset-base"
174
+ # question_model_name="facebook/dpr-question_encoder-multiset-base"
175
+
176
+ rag = RAG(file_path, device)
177
 
178
  st.title("RAG Model Query Interface")
179
 
180
+ # offer to ask a question and get an answer. make it so they can ask as many questions as they want
181
+
182
+ question = st.text_input("Ask a question", "What is another name for self-attention?")
183
 
184
+ if st.button("Ask"):
185
+ answer = rag.extractive_query(question)
186
+ st.write(answer)