rtabrizi committed on
Commit
c7191ea
1 Parent(s): 6a8529c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -21
app.py CHANGED
@@ -97,30 +97,24 @@ class Retriever:
97
 
98
  retrieved_texts = [' '.join(self.chunks[i].split('\n')) for i in I[0]] # Replacing newlines with spaces
99
 
100
- scores = [d for d in D[0]]
101
-
102
  return retrieved_texts
103
 
 
104
  class RAG:
105
  def __init__(self,
106
  file_path,
107
  device,
108
  context_model_name="facebook/dpr-ctx_encoder-multiset-base",
109
  question_model_name="facebook/dpr-question_encoder-multiset-base",
110
- generator_name="facebook/bart-large"):
111
 
112
  # generator_name = "valhalla/bart-large-finetuned-squadv1"
113
  # generator_name = "'vblagoje/bart_lfqa'"
114
  # generator_name = "a-ware/bart-squadv2"
115
 
116
- self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
117
- self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)
118
-
119
- # generator_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
120
- # generator_name = "t5-small"
121
-
122
- # self.generator_tokenizer = T5Tokenizer.from_pretrained(generator_name)
123
- # self.generator_model = T5ForConditionalGeneration.from_pretrained(generator_name)
124
 
125
  self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
126
  self.retriever.load_chunks()
@@ -128,8 +122,9 @@ class RAG:
128
 
129
 
130
  def abstractive_query(self, question):
 
 
131
  context = self.retriever.retrieve_top_k(question, k=5)
132
- # input_text = question + " " + " ".join(context)
133
 
134
  input_text = "answer: " + " ".join(context) + " " + question
135
 
@@ -141,12 +136,9 @@ class RAG:
141
 
142
  def extractive_query(self, question):
143
  context = self.retriever.retrieve_top_k(question, k=15)
144
- generator_name = "valhalla/bart-large-finetuned-squadv1"
145
-
146
- self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
147
- self.generator_model = BartForQuestionAnswering.from_pretrained(generator_name).to(device)
148
-
149
- inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=200 , padding="max_length")
150
  with torch.no_grad():
151
  model_inputs = inputs.to(device)
152
  outputs = self.generator_model(**model_inputs)
@@ -163,11 +155,9 @@ class RAG:
163
  answer = answer.replace('$', '')
164
 
165
  return answer
166
-
167
  context_model_name="facebook/dpr-ctx_encoder-single-nq-base"
168
  question_model_name = "facebook/dpr-question_encoder-single-nq-base"
169
- # context_model_name="facebook/dpr-ctx_encoder-multiset-base"
170
- # question_model_name="facebook/dpr-question_encoder-multiset-base"
171
 
172
  rag = RAG(file_path, device)
173
 
 
97
 
98
  retrieved_texts = [' '.join(self.chunks[i].split('\n')) for i in I[0]] # Replacing newlines with spaces
99
 
 
 
100
  return retrieved_texts
101
 
102
+
103
  class RAG:
104
  def __init__(self,
105
  file_path,
106
  device,
107
  context_model_name="facebook/dpr-ctx_encoder-multiset-base",
108
  question_model_name="facebook/dpr-question_encoder-multiset-base",
109
+ generator_name="valhalla/bart-large-finetuned-squadv1"):
110
 
111
  # generator_name = "valhalla/bart-large-finetuned-squadv1"
112
  # generator_name = "'vblagoje/bart_lfqa'"
113
  # generator_name = "a-ware/bart-squadv2"
114
 
115
+ generator_name = "valhalla/bart-large-finetuned-squadv1"
116
+ self.generator_tokenizer = AutoTokenizer.from_pretrained(generator_name)
117
+ self.generator_model = BartForQuestionAnswering.from_pretrained(generator_name).to(device)
 
 
 
 
 
118
 
119
  self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
120
  self.retriever.load_chunks()
 
122
 
123
 
124
  def abstractive_query(self, question):
125
+ self.generator_tokenizer = BartTokenizer.from_pretrained(self.generator_name)
126
+ self.generator_model = BartForConditionalGeneration.from_pretrained(self.generator_name).to(device)
127
  context = self.retriever.retrieve_top_k(question, k=5)
 
128
 
129
  input_text = "answer: " + " ".join(context) + " " + question
130
 
 
136
 
137
  def extractive_query(self, question):
138
  context = self.retriever.retrieve_top_k(question, k=15)
139
+
140
+
141
+ inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=300 , padding="max_length")
 
 
 
142
  with torch.no_grad():
143
  model_inputs = inputs.to(device)
144
  outputs = self.generator_model(**model_inputs)
 
155
  answer = answer.replace('$', '')
156
 
157
  return answer
158
+
159
  context_model_name="facebook/dpr-ctx_encoder-single-nq-base"
160
  question_model_name = "facebook/dpr-question_encoder-single-nq-base"
 
 
161
 
162
  rag = RAG(file_path, device)
163