legacy107 committed
Commit 0811f96
1 Parent(s): 94b48b9

Update app.py

Files changed (1): app.py (+23, -9)
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from gradio.components import Textbox
+from gradio.components import Textbox, Checkbox
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
 from peft import PeftModel
 import torch
@@ -19,6 +19,7 @@ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
 model_name = "google/flan-t5-large"
 peft_name = "legacy107/flan-t5-large-ia3-covidqa"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+pretrained_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
 model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
 model = PeftModel.from_pretrained(model, peft_name)
 
@@ -32,7 +33,7 @@ max_target_length = 200
 # Load your dataset
 dataset = datasets.load_dataset("minh21/COVID-QA-Chunk-64-testset-biencoder-data-90_10", split="train")
 dataset = dataset.shuffle()
-dataset = dataset.select(range(5))
+dataset = dataset.select(range(10))
 
 # Context chunking
 min_sentences_per_chunk = 3
@@ -138,7 +139,7 @@ def retrieve_context(question, contexts):
 
 
 # Define your function to generate answers
-def generate_answer(question, context, ground):
+def generate_answer(question, context, ground, do_pretrained, do_natural):
     contexts = chunk_splitter(clean_data(context))
     context = retrieve_context(question, contexts)
 
@@ -162,9 +163,18 @@ def generate_answer(question, context, ground):
     generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
     # Paraphrase answer
-    paraphrased_answer = paraphrase_answer(question, generated_answer)
+    paraphrased_answer = ""
+    if do_natural:
+        paraphrased_answer = paraphrase_answer(question, generated_answer)
 
-    return generated_answer, context, paraphrased_answer
+    # Get pretrained model's answer
+    pretrained_answer = ""
+    if do_pretrained:
+        with torch.no_grad():
+            pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+        pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
+
+    return generated_answer, context, paraphrased_answer, pretrained_answer
 
 
 # Define a function to list examples from the dataset
@@ -174,7 +184,7 @@ def list_examples():
         context = example["context"]
         question = example["question"]
         answer = example["answer"]
-        examples.append([question, context, answer])
+        examples.append([question, context, answer, True, True])
     return examples
 
 
@@ -184,14 +194,18 @@ iface = gr.Interface(
     inputs=[
         Textbox(label="Question"),
         Textbox(label="Context"),
-        Textbox(label="Ground truth")
+        Textbox(label="Ground truth"),
+        Checkbox(label="Include pretrained model's result"),
+        Checkbox(label="Include natural answer")
     ],
     outputs=[
         Textbox(label="Generated Answer"),
         Textbox(label="Retrieved Context"),
-        Textbox(label="Natural Answer")
+        Textbox(label="Natural Answer"),
+        Textbox(label="Pretrained Model's Answer")
    ],
-    examples=list_examples()
+    examples=list_examples(),
+    examples_per_page=1,
 )
 
 # Launch the Gradio interface
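
For context, here is a minimal standalone sketch of the baseline-vs-adapter comparison this commit enables: a frozen flan-t5-large for reference next to the same base weights wrapped with the IA3 adapter, exactly as the new pretrained_model line does. Only the model and adapter names come from the diff; the prompt string is a made-up placeholder, since the app's actual prompt construction is outside these hunks.

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
# Frozen baseline, as added by this commit
pretrained_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
# Same base weights with the IA3 adapter applied on top
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
model = PeftModel.from_pretrained(model, "legacy107/flan-t5-large-ia3-covidqa")

prompt = "question: What virus causes COVID-19? context: ..."  # placeholder prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    tuned_ids = model.generate(input_ids=input_ids, max_new_tokens=200)
    base_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=200)
print("tuned:   ", tokenizer.decode(tuned_ids[0], skip_special_tokens=True))
print("baseline:", tokenizer.decode(base_ids[0], skip_special_tokens=True))

And a runnable sketch of the checkbox-gated interface wiring the commit introduces, with the models stubbed out so the input/output plumbing is visible in isolation. The answer function and its stub strings are hypothetical stand-ins for the Space's generate_answer; the labels, the extra True, True example fields, and examples_per_page=1 mirror the diff.

import gradio as gr

def answer(question, context, ground, do_pretrained, do_natural):
    # Stub: the real app runs the PEFT model here
    generated = f"stub answer to: {question}"
    # Each checkbox gates its optional output, as in the new generate_answer
    natural = f"paraphrased: {generated}" if do_natural else ""
    baseline = f"baseline: {generated}" if do_pretrained else ""
    return generated, context, natural, baseline

demo = gr.Interface(
    fn=answer,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Textbox(label="Context"),
        gr.Textbox(label="Ground truth"),
        gr.Checkbox(label="Include pretrained model's result"),
        gr.Checkbox(label="Include natural answer"),
    ],
    outputs=[
        gr.Textbox(label="Generated Answer"),
        gr.Textbox(label="Retrieved Context"),
        gr.Textbox(label="Natural Answer"),
        gr.Textbox(label="Pretrained Model's Answer"),
    ],
    # Example rows carry one value per input, booleans last (placeholder text)
    examples=[["What virus causes COVID-19?", "SARS-CoV-2 is ...", "SARS-CoV-2", True, True]],
    examples_per_page=1,
)

if __name__ == "__main__":
    demo.launch()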