legacy107 committed
Commit
4723439
1 Parent(s): 1d6aee5

Update app.py

Files changed (1)
  1. app.py +73 -8
app.py CHANGED
@@ -5,6 +5,10 @@ from peft import PeftModel
 import torch
 import datasets
 from sentence_transformers import CrossEncoder
+import re
+from nltk import sent_tokenize, word_tokenize
+import nltk
+nltk.download('punkt')
 
 # Load cross encoder
 top_k = 10
@@ -29,6 +33,69 @@ dataset = datasets.load_dataset("minh21/COVID-QA-Chunk-64-testset-biencoder-data
 dataset = dataset.shuffle()
 dataset = dataset.select(range(5))
 
+# Context chunking
+min_sentences_per_chunk = 3
+chunk_size = 64
+window_size = math.ceil(min_sentences_per_chunk * 0.25)
+over_lap_chunk_size = chunk_size * 0.25
+
+def chunk_splitter(context):
+    sentences = sent_tokenize(context)
+    chunks = []
+    current_chunk = []
+
+    for sentence in sentences:
+        if len(current_chunk) < min_sentences_per_chunk:
+            current_chunk.append(sentence)
+            continue
+        elif len(word_tokenize(' '.join(current_chunk) + " " + sentence)) < chunk_size:
+            current_chunk.append(sentence)
+            continue
+
+        chunks.append(' '.join(current_chunk))
+        new_chunk = current_chunk[-window_size:]
+        new_window = window_size
+        buffer_new_chunk = new_chunk
+
+        while len(word_tokenize(' '.join(new_chunk))) <= over_lap_chunk_size:
+            buffer_new_chunk = new_chunk
+            new_window += 1
+            new_chunk = current_chunk[-new_window:]
+            if new_window >= len(current_chunk):
+                break
+
+        current_chunk = buffer_new_chunk
+        current_chunk.append(sentence)
+
+
+    if current_chunk:
+        chunks.append(' '.join(current_chunk))
+
+    return chunks
+
+
+def clean_data(text):
+    # Extract abstract content
+    index = text.find("\nAbstract: ")
+    if index != -1:
+        cleaned_text = text[index + len("\nAbstract: "):]
+    else:
+        cleaned_text = text  # If "\nAbstract: " is not found, keep the original text
+
+    # Remove both http and https links using a regular expression
+    cleaned_text = re.sub(r'(http(s|)\/\/:( |)\S+)|(http(s|):\/\/( |)\S+)', '', cleaned_text)
+
+
+    # Remove DOI patterns like "doi:10.1371/journal.pone.0007211.s003"
+    cleaned_text = re.sub(r'doi:( |)\w+', '', cleaned_text)
+
+    # Remove the "(0.11 MB DOC)" pattern
+    cleaned_text = re.sub(r'\(0\.\d+ MB DOC\)', '', cleaned_text)
+
+    cleaned_text = re.sub(r'www\.\w+(.org|)', '', cleaned_text)
+
+    return cleaned_text
+
 
 def paraphrase_answer(question, answer):
     # Combine question and context
@@ -70,9 +137,8 @@ def retrieve_context(question, contexts):
 
 
 # Define your function to generate answers
-def generate_answer(question, context, contexts):
-    if type(contexts) is str:
-        contexts = contexts.split(',')
+def generate_answer(question, context):
+    contexts = chunk_splitter(clean_data(context))
     context = retrieve_context(question, contexts)
 
     # Combine question and context
@@ -97,7 +163,7 @@ def generate_answer(question, context, contexts):
     # Paraphrase answer
     paraphrased_answer = paraphrase_answer(question, generated_answer)
 
-    return generated_answer, paraphrased_answer
+    return generated_answer, context, paraphrased_answer
 
 
 # Define a function to list examples from the dataset
@@ -105,9 +171,8 @@ def list_examples():
     examples = []
     for example in dataset:
        context = example["context"]
-       contexts = example["context_chunks"]
        question = example["question"]
-       examples.append([question, context, contexts])
+       examples.append([question, context])
    return examples
 
 
@@ -116,11 +181,11 @@ iface = gr.Interface(
    fn=generate_answer,
    inputs=[
        Textbox(label="Question"),
-       Textbox(label="Context"),
-       Textbox(label="Contexts")
+       Textbox(label="Context")
    ],
    outputs=[
        Textbox(label="Generated Answer"),
+       Textbox(label="Retrieved Context"),
        Textbox(label="Natural Answer")
    ],
    examples=list_examples()
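Note: window_size is computed with math.ceil, but this commit does not add an import for math; presumably import math already exists earlier in app.py, outside the hunks shown. For reference, a quick standalone evaluation of the new chunking constants (the values in the comments simply re-evaluate the expressions from the diff):

import math

min_sentences_per_chunk = 3
chunk_size = 64  # rough word_tokenize token budget per chunk

# Number of trailing sentences carried over into the next chunk as overlap.
window_size = math.ceil(min_sentences_per_chunk * 0.25)  # ceil(0.75) == 1

# Token threshold the overlap window is grown to exceed (25% of chunk_size).
over_lap_chunk_size = chunk_size * 0.25  # 16.0

print(window_size, over_lap_chunk_size)  # 1 16.0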
 
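Note: clean_data strips boilerplate from a raw context before it is chunked. Below is a small standalone check of the same regular expressions; the sample string is made up for illustration and is not taken from the dataset:

import re

def clean_data(text):
    # Same logic and patterns as the new clean_data in app.py above.
    index = text.find("\nAbstract: ")
    cleaned_text = text[index + len("\nAbstract: "):] if index != -1 else text
    # http/https links
    cleaned_text = re.sub(r'(http(s|)\/\/:( |)\S+)|(http(s|):\/\/( |)\S+)', '', cleaned_text)
    # doi: tags
    cleaned_text = re.sub(r'doi:( |)\w+', '', cleaned_text)
    # "(0.11 MB DOC)" attachment markers
    cleaned_text = re.sub(r'\(0\.\d+ MB DOC\)', '', cleaned_text)
    # bare www addresses
    cleaned_text = re.sub(r'www\.\w+(.org|)', '', cleaned_text)
    return cleaned_text

sample = ("Title: Example\nAbstract: Supplementary material at https://example.org/s1 "
          "(0.11 MB DOC) doi:s003 is available.")
print(clean_data(sample))
# Prints the abstract text with the link, the doi tag and the "(0.11 MB DOC)"
# marker removed (the surrounding spaces are left behind).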
 
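Note: the interface now takes two inputs (question and raw context) and returns three outputs, with chunking and retrieval happening inside generate_answer, so list_examples() yields [question, context] pairs. A minimal self-contained sketch of that wiring with a stand-in answer function; the stub logic and example strings are placeholders rather than the real model pipeline:

import gradio as gr

# Stand-in for the real pipeline: clean_data -> chunk_splitter -> retrieve_context
# -> generation -> paraphrase_answer.
def generate_answer_stub(question, context):
    retrieved = context[:100]                # placeholder for the retrieved chunk
    answer = "stub answer for: " + question  # placeholder for the generated answer
    natural = answer                         # placeholder for the paraphrased answer
    return answer, retrieved, natural

iface = gr.Interface(
    fn=generate_answer_stub,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Textbox(label="Context"),
    ],
    outputs=[
        gr.Textbox(label="Generated Answer"),
        gr.Textbox(label="Retrieved Context"),
        gr.Textbox(label="Natural Answer"),
    ],
    examples=[["What is studied here?", "Title: Example\nAbstract: An illustrative context paragraph."]],
)

if __name__ == "__main__":
    iface.launch()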