# Captured from a Hugging Face Spaces page (status shown at capture time: "Runtime error").
import random
import re
import sys
from urllib.parse import unquote

import gradio as gr
import requests
import torch
import wikipediaapi
from bs4 import BeautifulSoup
from transformers import BartForConditionalGeneration, BartTokenizer
# Wikipedia API client. The first argument identifies this application to
# Wikipedia (presumably the user-agent string -- confirm against the installed
# wikipedia-api version); the second selects the English edition.
wiki_wiki = wikipediaapi.Wikipedia('MCQ Generation (r.j.a.lemein@student.rug.nl)', 'en')

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both fine-tuned checkpoints below are used with the stock bart-large tokenizer.
model_name_large = "facebook/bart-large"
tokenizer_large = BartTokenizer.from_pretrained(model_name_large)

# Stage 1: turns filtered web/Wikipedia text into a "Support: ..." passage.
model_WEB_large = BartForConditionalGeneration.from_pretrained('rizkiduwinanto/WEB_large').to(device)

# Stage 2: turns a support passage into a question, an answer and distractors.
model_QA_large = BartForConditionalGeneration.from_pretrained("b-b-brouwer/CL_large").to(device)
def _parse_qa_output(generated_text):
    """Split the QA model's output into (question, answer, distractors)."""
    question_pattern = r"Question: (.+?\?)"
    answer_pattern = r"Answer: (.+?)(?= Distractor\d+: |$)"
    distractor_pattern = r"Distractor\d+: (.+?)(?= Distractor|$)"

    question_match = re.search(question_pattern, generated_text)
    answer_match = re.search(answer_pattern, generated_text)

    # A missing label yields an empty string rather than an error.
    question = question_match.group(1) if question_match else ""
    answer = answer_match.group(1) if answer_match else ""
    distractors = re.findall(distractor_pattern, generated_text)
    return question, answer, distractors


def generate_QA(input_text):
    """Generate a multiple-choice question from a supporting text.

    Args:
        input_text: Text fed to the question-generation model.

    Returns:
        A ``(question, answer, distractors)`` tuple. ``question`` and
        ``answer`` are empty strings when the model output does not contain
        the corresponding ``Question:`` / ``Answer:`` label; ``distractors``
        is a (possibly empty) list of strings.
    """
    # Truncate exactly like generate_support does: without it, an over-long
    # support text is passed to generate() uncut (bart-large is presumably
    # limited to 1024 input positions -- see the model configuration).
    inputs = tokenizer_large(
        input_text, return_tensors="pt", max_length=1024, truncation=True
    ).to(device)

    outputs = model_QA_large.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=1024,
    )

    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
    return _parse_qa_output(generated_text)
def generate_support(url, topic, direct_wiki=False, not_print=False):
    """Generate a supporting passage about ``topic`` from a web page.

    Args:
        url: Address of the page. With ``direct_wiki`` the last path segment
            is used as the Wikipedia page title.
        topic: Topic used to filter the page text; it is also part of the
            model input.
        direct_wiki: When true, fetch the text through the Wikipedia API
            client instead of scraping ``url``.
        not_print: When true, suppress the progress messages.

    Returns:
        The part of the model output starting at ``Support:`` (whitespace
        stripped), or the whole model output when no such label is present.
    """
    verbose = not not_print

    if direct_wiki:
        # Tolerate a trailing slash and percent-encoded characters in the URL;
        # either would otherwise produce a title that matches no page.
        display_title = unquote(url.rstrip("/").split("/")[-1])
        if verbose:
            print(display_title)
        source_text = wiki_wiki.page(display_title).text
    else:
        source_text = extract_text_from_url(url)

    filtered_text = filter_paragraphs_by_topic(source_text, topic)
    if filtered_text is None:
        # Nothing could be fetched/filtered: feed an empty text instead of
        # letting the f-string below insert the literal word "None".
        filtered_text = ""

    text_label = "Text: "
    answer_label = "Answer: "
    # The model input carries the topic as well as the filtered text.
    merged_column_input = f"{answer_label} {topic} {text_label} {filtered_text}"

    inputs = tokenizer_large(
        merged_column_input, return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    outputs = model_WEB_large.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=1024,
    )
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)

    if verbose:
        print(f"Output of url to support generator: {generated_text}.\n")

    # Keep everything from the "Support:" label to the end of the output.
    matcher = re.search(r'Support:.*$', generated_text)
    if matcher is None:
        if verbose:
            print("No support found in the output string. The next model will be fed with all generated text.\n This might cause strange results.")
        return generated_text

    # str.strip() returns a new string; the result must be kept.
    support = matcher.group(0).strip()
    if verbose:
        print(f"Found support in output: {support}.\n")
    return support
def extract_text_from_url(url):
    """Download ``url`` and return the text of all its ``<p>`` elements.

    Args:
        url: Address of the page to fetch.

    Returns:
        The paragraph texts joined by single spaces, or ``None`` when the
        request fails or the response status is not 200.
    """
    try:
        # Without a timeout a stalled server would block this call forever.
        response = requests.get(url, timeout=10)
    except requests.RequestException as e:
        print("Error occurred while fetching URL:", str(e), flush=True)
        return None

    if response.status_code != 200:
        print("Failed to fetch URL:", url, flush=True)
        return None

    soup = BeautifulSoup(response.content, 'html.parser')
    # NOTE(review): paragraphs are joined with spaces, while
    # filter_paragraphs_by_topic splits on newlines -- scraped pages are
    # therefore filtered almost as one paragraph. Confirm this is intended.
    return ' '.join(p.get_text() for p in soup.find_all('p'))
def filter_paragraphs_by_topic(text, topic):
    """Keep only the paragraphs of ``text`` that mention ``topic``.

    Paragraphs are the newline-separated pieces of ``text``; matching is
    case-insensitive. When no paragraph contains the whole topic, each
    whitespace-separated word of the topic is tried in turn and every
    paragraph containing that word is kept once.

    Args:
        text: The text to filter, or ``None`` when nothing was fetched.
        topic: The topic to look for.

    Returns:
        The matching paragraphs joined by single spaces (``""`` when nothing
        matches), or ``None`` when ``text`` or ``topic`` is not a string.
    """
    # Explicit check instead of a bare ``except:`` around the whole body, which
    # would also hide unrelated errors.
    if not isinstance(text, str) or not isinstance(topic, str):
        print("No text was fetched for this topic, as no page was found, so we cannot filter it.", flush=True)
        return None

    paragraphs = text.split('\n')
    # Lower-case everything once instead of once per comparison.
    lowered = [paragraph.lower() for paragraph in paragraphs]
    needle = topic.lower()

    relevant_paragraphs = [
        paragraph for paragraph, low in zip(paragraphs, lowered) if needle in low
    ]

    if not relevant_paragraphs:
        # Fall back to the individual words. ``split()`` (no argument) drops
        # the empty strings that repeated spaces would produce; an empty word
        # is contained in every paragraph and would defeat the filter.
        seen = set()
        for word in needle.split():
            for paragraph, low in zip(paragraphs, lowered):
                if word in low and paragraph not in seen:
                    seen.add(paragraph)
                    relevant_paragraphs.append(paragraph)

    return ' '.join(relevant_paragraphs)
def generate_QA_from_url(url, topic, direct_wiki=False, not_print_support=True):
    """Run both stages: build a support passage for ``topic`` from ``url``,
    then generate a question, answer and distractors from that passage.

    Returns the ``(question, answer, distractors)`` tuple of ``generate_QA``.
    """
    support_text = generate_support(
        url,
        topic,
        direct_wiki=direct_wiki,
        not_print=not_print_support,
    )
    return generate_QA(support_text)
def prompt(url, topic):
    """Produce one multiple-choice question with four shuffled options.

    Args:
        url: Page to generate the question from.
        topic: Topic the question should be about.

    Returns:
        A 6-tuple ``(question, answer, option1, option2, option3, option4)``.
        The options are the correct answer plus three distractors in random
        order.
    """
    question, answer, distractors = generate_QA_from_url(url, topic)

    # Use exactly three distractors. With fewer, indexing four options would
    # raise IndexError (pad with empty strings instead); with more, the answer
    # could be shuffled past the fourth position and never be shown.
    options = list(distractors[:3])
    options.extend([""] * (3 - len(options)))
    options.append(answer)
    random.shuffle(options)

    return (question, answer, *options)
with gr.Blocks() as demo:
    # State: the correct answer and the text currently shown on each button.
    answer = gr.State('')
    d1, d2, d3, d4 = [gr.State('') for _ in range(4)]

    # Inputs.
    url = gr.Text(label="URL")
    topic = gr.Text(label="Topic")
    submit = gr.Button("Submit")

    # Outputs: the question, four answer buttons and a result field. The
    # buttons and the result field stay hidden until they have content.
    question = gr.Textbox(label="question")
    choice1, choice2, choice3, choice4 = [
        gr.Button(value=f"choice{number}", interactive=True, visible=False)
        for number in range(1, 5)
    ]
    res = gr.Text(value="correct", label="Results", interactive=True, visible=False)

    def check(correct_answer, picked):
        """Return the (now visible) result field saying whether ``picked`` is right."""
        verdict = 'correct' if picked == correct_answer else 'wrong'
        return gr.Text(value=verdict, label="result", interactive=True, visible=True)

    def on_submit(url, topic):
        """Generate a question and return the values for every output component."""
        question_text, correct, text1, text2, text3, text4 = prompt(url, topic)
        option_texts = (text1, text2, text3, text4)
        # One visible button per option, labelled with the option text.
        buttons = [
            gr.Button(value=option, interactive=True, visible=True)
            for option in option_texts
        ]
        return (question_text, correct, *buttons, *option_texts)

    gr.on(
        triggers=[submit.click],
        fn=on_submit,
        inputs=[url, topic],
        outputs=[question, answer, choice1, choice2, choice3, choice4, d1, d2, d3, d4]
    )

    # Clicking a button compares the text it shows with the correct answer.
    for button, shown_text in zip((choice1, choice2, choice3, choice4), (d1, d2, d3, d4)):
        button.click(check, inputs=[answer, shown_text], outputs=res)

demo.launch(share=True)