legacy107 committed on
Commit
c4569dd
1 Parent(s): 6e67454

Add paraphrased answer

Browse files
Files changed (1) hide show
  1. app.py +43 -6
app.py CHANGED
@@ -1,23 +1,54 @@
1
  import gradio as gr
2
  from gradio.components import Textbox
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
4
  import torch
5
  import datasets
6
 
7
  # Load your fine-tuned model and tokenizer
8
- model_name = "legacy107/flan-t5-large-bottleneck-adapter-cpgQA-unique"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
11
  model.set_active_adapters("question_answering")
 
 
 
 
 
 
12
  max_length = 512
13
  max_target_length = 128
14
 
15
  # Load your dataset
16
- dataset = datasets.load_dataset("minh21/cpgQA-v1.0-unique-context-test-10-percent", split="test")
17
  dataset = dataset.shuffle()
18
  dataset = dataset.select(range(5))
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Define your function to generate answers
22
  def generate_answer(question, context):
23
  # Combine question and context
@@ -39,7 +70,10 @@ def generate_answer(question, context):
39
  # Decode and return the generated answer
40
  generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
41
 
42
- return generated_answer
 
 
 
43
 
44
 
45
  # Define a function to list examples from the dataset
@@ -59,7 +93,10 @@ iface = gr.Interface(
59
  Textbox(label="Question"),
60
  Textbox(label="Context")
61
  ],
62
- outputs=Textbox(label="Generated Answer"),
 
 
 
63
  examples=list_examples()
64
  )
65
 
 
1
import gradio as gr
from gradio.components import Textbox
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from peft import PeftModel, PeftConfig
import torch
import datasets

# ---- Question-answering model -------------------------------------------
# Base checkpoint shared by both the QA model and the paraphrase model.
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Reuse `model_name` instead of repeating the checkpoint string (was a
# duplicated hard-coded literal).
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Bottleneck adapter fine-tuned on cpgQA, loaded from the Hugging Face Hub.
model.load_adapter("legacy107/adapter-flan-t5-large-bottleneck-adapter-cpgQA", source="hf")
model.set_active_adapters("question_answering")

# ---- Paraphrase model ----------------------------------------------------
# IA3 PEFT adapter on top of the same base checkpoint; used to rewrite the
# generated answer into a more natural phrasing.
peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"
peft_config = PeftConfig.from_pretrained(peft_name)  # NOTE(review): loaded but unused below
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
paraphrase_model = PeftModel.from_pretrained(paraphrase_model, peft_name)

# Tokenization / generation limits shared by both pipelines.
max_length = 512
max_target_length = 128

# ---- Demo examples -------------------------------------------------------
# Pick 5 random examples from the held-out test split to show in the UI.
dataset = datasets.load_dataset("minh21/cpgQA-v1.0-unique-context-test-10-percent-validation-10-percent", split="test")
dataset = dataset.shuffle()
dataset = dataset.select(range(5))
27
 
28
 
29
def paraphrase_answer(question, answer):
    """Rewrite a generated answer into a more natural-sounding phrasing.

    Builds a prompt from the question and the raw answer, then runs the
    IA3 paraphrase model. Relies on the module-level `tokenizer`,
    `paraphrase_model`, `max_length`, and `max_target_length`.
    """
    # Prompt format expected by the paraphrase adapter.
    prompt = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"

    # Tokenize the prompt (padded/truncated to the model's input budget).
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        output_ids = paraphrase_model.generate(
            encoded.input_ids, max_new_tokens=max_target_length
        )

    # Decode the first (and only) sequence back to text.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
50
+
51
+
52
  # Define your function to generate answers
53
  def generate_answer(question, context):
54
  # Combine question and context
 
70
  # Decode and return the generated answer
71
  generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
72
 
73
+ # Paraphrase answer
74
+ paraphrased_answer = paraphrase_answer(question, generated_answer)
75
+
76
+ return generated_answer, paraphrased_answer
77
 
78
 
79
  # Define a function to list examples from the dataset
 
93
  Textbox(label="Question"),
94
  Textbox(label="Context")
95
  ],
96
+ outputs=[
97
+ Textbox(label="Generated Answer"),
98
+ Textbox(label="Natural Answer")
99
+ ],
100
  examples=list_examples()
101
  )
102