legacy107 committed
Commit
b4a7297
1 Parent(s): 1ecf571

Create app.py

Files changed (1)
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
+ import gradio as gr
+ from gradio.components import Textbox
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
+ from peft import PeftModel
+ import torch
+ import datasets
+ from sentence_transformers import CrossEncoder
+
+ # Load the cross-encoder used to rerank context chunks
+ top_k = 10
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+ # Load the fine-tuned QA model and tokenizer
+ model_name = "google/flan-t5-large"
+ peft_name = "legacy107/flan-t5-large-ia3-covidqa"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
+ model = PeftModel.from_pretrained(model, peft_name)
+
+ # Load the paraphrasing model on the same base checkpoint
+ peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"
+ paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ paraphrase_model = PeftModel.from_pretrained(paraphrase_model, peft_name)
+
+ max_length = 512
+ max_target_length = 200
+
+ # Load the evaluation dataset and sample five examples for the demo
+ dataset = datasets.load_dataset("minh21/COVID-QA-Chunk-64-testset-biencoder-data-90_10", split="test")
+ dataset = dataset.shuffle()
+ dataset = dataset.select(range(5))
+
+
+ def paraphrase_answer(question, answer):
+     # Combine the question and answer into a paraphrasing prompt
+     input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
+
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+
+     # Generate the paraphrased answer
+     with torch.no_grad():
+         generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+
+     # Decode and return the generated answer
+     paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     return paraphrased_answer
+
+
+ def retrieve_context(question, contexts):
+     # Score each (question, chunk) pair with the cross-encoder
+     hits = [{"corpus_id": i} for i in range(len(contexts))]
+     cross_inp = [[question, contexts[hit["corpus_id"]]] for hit in hits]
+     cross_scores = cross_encoder.predict(cross_inp, show_progress_bar=False)
+
+     for idx in range(len(cross_scores)):
+         hits[idx]["cross-score"] = cross_scores[idx]
+
+     # Keep the top-k highest-scoring chunks, joined into one context string
+     hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
+
+     return " ".join(
+         [contexts[hit["corpus_id"]] for hit in hits[0:top_k]]
+     ).replace("\n", " ")
+
+
+ # Generate an answer from the retrieved context, then paraphrase it
+ def generate_answer(question, context, contexts):
+     # Replace the displayed context with the top-k reranked chunks
+     context = retrieve_context(question, contexts)
+
+     # Combine question and context
+     input_text = f"question: {question} context: {context}"
+
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+
+     # Generate the answer
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids, max_new_tokens=max_target_length)
+
+     # Decode the generated answer
+     generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     # Paraphrase the answer into a more natural phrasing
+     paraphrased_answer = paraphrase_answer(question, generated_answer)
+
+     return generated_answer, paraphrased_answer
+
+
+ # Define a function to list examples from the dataset
+ def list_examples():
+     examples = []
+     for example in dataset:
+         context = example["context"]
+         contexts = example["context_chunks"]
+         question = example["question"]
+         examples.append([question, context, contexts])
+     return examples
+
+
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=generate_answer,
+     inputs=[
+         Textbox(label="Question"),
+         Textbox(label="Context"),
+         Textbox(label="Contexts")
+     ],
+     outputs=[
+         Textbox(label="Generated Answer"),
+         Textbox(label="Natural Answer")
+     ],
+     examples=list_examples()
+ )
+
+ # Launch the Gradio interface
+ iface.launch()
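
For reference, the core of retrieve_context above is a standard cross-encoder reranking pass. The sketch below isolates that step outside the app, assuming only sentence-transformers is installed; the question and passages are invented for illustration and are not data from this commit.

# Standalone sketch of the reranking step used in retrieve_context.
# The passages below are made up; actual scores depend on the
# downloaded model weights.
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

question = "How does the virus spread?"
chunks = [
    "The virus spreads mainly through respiratory droplets.",
    "The study was funded by a public health grant.",
    "Transmission can also occur via contaminated surfaces.",
]

# Score every (question, chunk) pair, then sort chunks by relevance
scores = cross_encoder.predict([[question, c] for c in chunks])
ranked = sorted(zip(scores, chunks), key=lambda x: x[0], reverse=True)
for score, chunk in ranked:
    print(f"{score:.3f}  {chunk}")

In the app, the top_k highest-scoring chunks are concatenated into the context fed to the QA model, so only the reranking order matters, not the absolute score values.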