Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from gradio.components import Textbox
|
3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
|
4 |
from peft import PeftModel
|
5 |
import torch
|
@@ -19,6 +19,7 @@ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
|
19 |
model_name = "google/flan-t5-large"
|
20 |
peft_name = "legacy107/flan-t5-large-ia3-covidqa"
|
21 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
22 |
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
|
23 |
model = PeftModel.from_pretrained(model, peft_name)
|
24 |
|
@@ -32,7 +33,7 @@ max_target_length = 200
|
|
32 |
# Load your dataset
|
33 |
dataset = datasets.load_dataset("minh21/COVID-QA-Chunk-64-testset-biencoder-data-90_10", split="train")
|
34 |
dataset = dataset.shuffle()
|
35 |
-
dataset = dataset.select(range(
|
36 |
|
37 |
# Context chunking
|
38 |
min_sentences_per_chunk = 3
|
@@ -138,7 +139,7 @@ def retrieve_context(question, contexts):
|
|
138 |
|
139 |
|
140 |
# Define your function to generate answers
|
141 |
-
def generate_answer(question, context, ground):
|
142 |
contexts = chunk_splitter(clean_data(context))
|
143 |
context = retrieve_context(question, contexts)
|
144 |
|
@@ -162,9 +163,18 @@ def generate_answer(question, context, ground):
|
|
162 |
generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
163 |
|
164 |
# Paraphrase answer
|
165 |
-
paraphrased_answer =
|
|
|
|
|
166 |
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
|
170 |
# Define a function to list examples from the dataset
|
@@ -174,7 +184,7 @@ def list_examples():
|
|
174 |
context = example["context"]
|
175 |
question = example["question"]
|
176 |
answer = example["answer"]
|
177 |
-
examples.append([question, context, answer])
|
178 |
return examples
|
179 |
|
180 |
|
@@ -184,14 +194,18 @@ iface = gr.Interface(
|
|
184 |
inputs=[
|
185 |
Textbox(label="Question"),
|
186 |
Textbox(label="Context"),
|
187 |
-
Textbox(label="Ground truth")
|
|
|
|
|
188 |
],
|
189 |
outputs=[
|
190 |
Textbox(label="Generated Answer"),
|
191 |
Textbox(label="Retrieved Context"),
|
192 |
-
Textbox(label="Natural Answer")
|
|
|
193 |
],
|
194 |
-
examples=list_examples()
|
|
|
195 |
)
|
196 |
|
197 |
# Launch the Gradio interface
|
|
|
1 |
import gradio as gr
|
2 |
+
from gradio.components import Textbox, Checkbox
|
3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
|
4 |
from peft import PeftModel
|
5 |
import torch
|
|
|
19 |
model_name = "google/flan-t5-large"
|
20 |
peft_name = "legacy107/flan-t5-large-ia3-covidqa"
|
21 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
22 |
+
pretrained_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
|
23 |
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
|
24 |
model = PeftModel.from_pretrained(model, peft_name)
|
25 |
|
|
|
33 |
# Load your dataset
|
34 |
dataset = datasets.load_dataset("minh21/COVID-QA-Chunk-64-testset-biencoder-data-90_10", split="train")
|
35 |
dataset = dataset.shuffle()
|
36 |
+
dataset = dataset.select(range(10))
|
37 |
|
38 |
# Context chunking
|
39 |
min_sentences_per_chunk = 3
|
|
|
139 |
|
140 |
|
141 |
# Define your function to generate answers
|
142 |
+
def generate_answer(question, context, ground, do_pretrained, do_natural):
|
143 |
contexts = chunk_splitter(clean_data(context))
|
144 |
context = retrieve_context(question, contexts)
|
145 |
|
|
|
163 |
generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
164 |
|
165 |
# Paraphrase answer
|
166 |
+
paraphrased_answer = ""
|
167 |
+
if do_natural:
|
168 |
+
paraphrased_answer = paraphrase_answer(question, generated_answer)
|
169 |
|
170 |
+
# Get pretrained model's answer
|
171 |
+
pretrained_answer = ""
|
172 |
+
if do_pretrained:
|
173 |
+
with torch.no_grad():
|
174 |
+
pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
|
175 |
+
pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
|
176 |
+
|
177 |
+
return generated_answer, context, paraphrased_answer, pretrained_answer
|
178 |
|
179 |
|
180 |
# Define a function to list examples from the dataset
|
|
|
184 |
context = example["context"]
|
185 |
question = example["question"]
|
186 |
answer = example["answer"]
|
187 |
+
examples.append([question, context, answer, True, True])
|
188 |
return examples
|
189 |
|
190 |
|
|
|
194 |
inputs=[
|
195 |
Textbox(label="Question"),
|
196 |
Textbox(label="Context"),
|
197 |
+
Textbox(label="Ground truth"),
|
198 |
+
Checkbox(label="Include pretrained model's result"),
|
199 |
+
Checkbox(label="Include natural answer")
|
200 |
],
|
201 |
outputs=[
|
202 |
Textbox(label="Generated Answer"),
|
203 |
Textbox(label="Retrieved Context"),
|
204 |
+
Textbox(label="Natural Answer"),
|
205 |
+
Textbox(label="Pretrained Model's Answer")
|
206 |
],
|
207 |
+
examples=list_examples(),
|
208 |
+
examples_per_page=1,
|
209 |
)
|
210 |
|
211 |
# Launch the Gradio interface
|