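"""Gradio demo that turns a web page or Wikipedia article into a multiple-choice question.

Pipeline: fetch the page text, filter its paragraphs by topic, generate a
supporting sentence with a fine-tuned BART model (rizkiduwinanto/WEB_large),
then generate a question, answer, and three distractors with a second
fine-tuned BART model (b-b-brouwer/CL_large).
"""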
import requests
from bs4 import BeautifulSoup
import sys
from transformers import BartForConditionalGeneration, BartTokenizer
import re
import torch
import wikipediaapi
import gradio as gr
import random
# User agent string (with contact email) required by the Wikipedia API
wiki_wiki = wikipediaapi.Wikipedia('MCQ Generation (r.j.a.lemein@student.rug.nl)', 'en')
# Load the fine-tuned models from the Hugging Face Hub
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_WEB_large = BartForConditionalGeneration.from_pretrained('rizkiduwinanto/WEB_large').to(device)
model_name_large = "facebook/bart-large"
model_QA_large = BartForConditionalGeneration.from_pretrained("b-b-brouwer/CL_large").to(device)
tokenizer_large = BartTokenizer.from_pretrained(model_name_large)
def generate_QA(input_text):
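    """Generate a question, answer, and distractors from a support sentence.

    The fine-tuned model is expected to emit text of the form
    "Question: ...? Answer: ... Distractor1: ... Distractor2: ... Distractor3: ...",
    which is parsed here with regular expressions.
    """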
    # Tokenize the input text
    inputs = tokenizer_large(input_text, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    # Generate outputs
    outputs = model_QA_large.generate(input_ids=input_ids, max_length=1024)
    # Decode the generated outputs
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
    # Parse the question and answer out of the generated text
    question_pattern = r"Question: (.+?\?)"
    answer_pattern = r"Answer: (.+?)(?= Distractor\d+: |$)"
    question = ""
    answer = ""
    question_match = re.search(question_pattern, generated_text)
    answer_match = re.search(answer_pattern, generated_text)
    if question_match:
        question = question_match.group(1)
    if answer_match:
        answer = answer_match.group(1)
    # One pattern matches all numbered distractors
    distractor_pattern = r"Distractor\d+: (.+?)(?= Distractor|$)"
    distractors = re.findall(distractor_pattern, generated_text)
    return question, answer, distractors
def generate_support(url, topic, direct_wiki=False, not_print=False):
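    """Generate a supporting sentence about `topic` from a URL or Wikipedia title.

    With direct_wiki=True, the last path segment of `url` is treated as a
    Wikipedia page title and fetched through the API; otherwise the page is
    scraped directly. The topic-filtered text is fed to the support model.
    """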
    # Obtain the source text
    if not direct_wiki:
        text = extract_text_from_url(url)
        filtered_text = filter_paragraphs_by_topic(text, topic)
    else:
        display_title = url.split("/")[-1]
        if not not_print:
            print(display_title)
        page = wiki_wiki.page(display_title)
        plain_text_content = page.text
        filtered_text = filter_paragraphs_by_topic(plain_text_content, topic)
    text_label = "Text: "
    answer_label = "Answer: "
    # These texts are the inputs; supports are the outputs
    merged_column_input = f"{answer_label} {topic} {text_label} {filtered_text}"
    # Tokenize the filtered text. Note that the large model takes the topic as an input as well!
    inputs = tokenizer_large(merged_column_input, return_tensors="pt", max_length=1024, truncation=True).to(device)
    input_ids = inputs["input_ids"]
    # Generate outputs
    outputs = model_WEB_large.generate(input_ids=input_ids, max_length=1024)
    # Decode the generated outputs
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
    # Print the generated text
    if not not_print:
        print(f"Output of url to support generator: {generated_text}.\n")
    # Extract the support with a regular expression
    pattern = r'Support:.*$'
    matcher = re.search(pattern, generated_text)
    if matcher:
        support = matcher.group(0).strip()  # Remove leading/trailing whitespace
        if not not_print:
            print(f"Found support in output: {support}.\n")
        return support
    else:
        if not not_print:
            print("No support found in the output string. The next model will be fed all of the generated text.\n"
                  "This might cause strange results.")
        return generated_text
def extract_text_from_url(url):
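    """Fetch a web page and return the concatenated text of its <p> tags, or None on failure."""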
    try:
        response = requests.get(url, timeout=10)  # Avoid hanging indefinitely on unresponsive hosts
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, 'html.parser')
            # Extract text from all paragraph tags
            paragraphs = [p.get_text() for p in soup.find_all('p')]
            return ' '.join(paragraphs)  # Concatenate paragraphs into a single text
        else:
            print("Failed to fetch URL:", url)
            sys.stdout.flush()
            return None
    except Exception as e:
        print("Error occurred while fetching URL:", str(e))
        sys.stdout.flush()
        return None
def filter_paragraphs_by_topic(text, topic):
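    """Keep only the paragraphs of `text` that mention the topic (or, as a fallback, any of its words)."""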
    relevant_paragraphs = []
    try:
        # Split the text into paragraphs based on newline characters
        paragraphs = text.split('\n')
        # Simple approach: keep a paragraph if it contains the full topic string
        for paragraph in paragraphs:
            if topic.lower() in paragraph.lower():
                relevant_paragraphs.append(paragraph)
        # The simple approach doesn't always work, so fall back to matching individual topic words
        if len(relevant_paragraphs) == 0:
            topics = topic.split(' ')
            for split_topic in topics:
                for paragraph in paragraphs:
                    if split_topic.lower() in paragraph.lower() and paragraph not in relevant_paragraphs:
                        relevant_paragraphs.append(paragraph)
    except AttributeError:
        # text is None when no page was fetched, so there is nothing to filter
        print("No text was fetched for this topic, as no page was found, so we cannot filter it.")
        sys.stdout.flush()
        return None
    return ' '.join(relevant_paragraphs)
def generate_QA_from_url(url, topic, direct_wiki=False, not_print_support=True):
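    """Convenience wrapper: URL/topic -> support sentence -> (question, answer, distractors).

    Example call (hypothetical inputs):
        generate_QA_from_url("https://en.wikipedia.org/wiki/Photosynthesis", "photosynthesis", direct_wiki=True)
    """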
    return generate_QA(generate_support(url, topic, direct_wiki=direct_wiki, not_print=not_print_support))
def prompt(url, topic):
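    """Build one multiple-choice item; the correct answer is shuffled in among the distractors."""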
    question, answer, distractors = generate_QA_from_url(url, topic)
    # Trim or pad to exactly three distractors so the answer is always among the four choices
    distractors = distractors[:3]
    while len(distractors) < 3:
        distractors.append("")
    distractors.append(answer)
    random.shuffle(distractors)
    return question, answer, distractors[0], distractors[1], distractors[2], distractors[3]
with gr.Blocks() as demo:
    # Hidden state holding the correct answer and the text of each choice button
    answer = gr.State('')
    d1 = gr.State('')
    d2 = gr.State('')
    d3 = gr.State('')
    d4 = gr.State('')
    url = gr.Text(label="URL")
    topic = gr.Text(label="Topic")
    submit = gr.Button("Submit")
    question = gr.Textbox(label="question")
    choice1 = gr.Button(value="choice1", interactive=True, visible=False)
    choice2 = gr.Button(value="choice2", interactive=True, visible=False)
    choice3 = gr.Button(value="choice3", interactive=True, visible=False)
    choice4 = gr.Button(value="choice4", interactive=True, visible=False)
    res = gr.Text(value="correct", label="Results", interactive=True, visible=False)

    def check(correct_answer, chosen):
        # Compare the clicked choice against the stored correct answer
        if chosen == correct_answer:
            return gr.Text(value='correct', label="result", interactive=True, visible=True)
        else:
            return gr.Text(value='wrong', label="result", interactive=True, visible=True)

    def on_submit(url, topic):
        question_text, answer, choice1_text, choice2_text, choice3_text, choice4_text = prompt(url, topic)
        # Reveal the choice buttons with the generated options as their labels
        choice1 = gr.Button(value=choice1_text, interactive=True, visible=True)
        choice2 = gr.Button(value=choice2_text, interactive=True, visible=True)
        choice3 = gr.Button(value=choice3_text, interactive=True, visible=True)
        choice4 = gr.Button(value=choice4_text, interactive=True, visible=True)
        return question_text, answer, choice1, choice2, choice3, choice4, choice1_text, choice2_text, choice3_text, choice4_text

    gr.on(
        triggers=[submit.click],
        fn=on_submit,
        inputs=[url, topic],
        outputs=[question, answer, choice1, choice2, choice3, choice4, d1, d2, d3, d4]
    )
    choice1.click(check, inputs=[answer, d1], outputs=res)
    choice2.click(check, inputs=[answer, d2], outputs=res)
    choice3.click(check, inputs=[answer, d3], outputs=res)
    choice4.click(check, inputs=[answer, d4], outputs=res)

demo.launch(share=True)