challenge-NLP / app.py
rizkiduwinanto's picture
add app.py
e20c0ed
import requests
from bs4 import BeautifulSoup
import sys
from transformers import BartForConditionalGeneration, BartTokenizer
import re
import torch
import wikipediaapi
import gradio as gr
import random
# Wikipedia API client. The first argument is the user-agent string that
# identifies this app to Wikipedia; 'en' selects English Wikipedia.
wiki_wiki = wikipediaapi.Wikipedia('MCQ Generation (r.j.a.lemein@student.rug.nl)', 'en')

# Run the models on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned BART checkpoints, loaded by repository name and moved to `device`:
# - WEB_large is used by generate_support (topic + page text -> support passage);
# - CL_large is used by generate_QA (support -> question/answer/distractors).
model_WEB_large = BartForConditionalGeneration.from_pretrained('rizkiduwinanto/WEB_large').to(device)
model_QA_large = BartForConditionalGeneration.from_pretrained("b-b-brouwer/CL_large").to(device)

# One tokenizer from the base model serves both checkpoints (presumably both
# were fine-tuned from facebook/bart-large -- TODO confirm).
model_name_large = "facebook/bart-large"
tokenizer_large = BartTokenizer.from_pretrained(model_name_large)
def _parse_qa_output(generated_text):
    """Split the QA model's decoded output into (question, answer, distractors).

    Looks for the markers "Question: ...?", "Answer: ..." and
    "Distractor<N>: ..." in *generated_text*. A missing question or answer
    comes back as "", and missing distractors give an empty list.
    """
    question_match = re.search(r"Question: (.+?\?)", generated_text)
    # The answer runs until the first " Distractor<N>: " marker or the end of the text.
    answer_match = re.search(r"Answer: (.+?)(?= Distractor\d+: |$)", generated_text)
    question = question_match.group(1) if question_match else ""
    answer = answer_match.group(1) if answer_match else ""
    # Each distractor runs until the next " Distractor" marker or the end of the text.
    distractors = re.findall(r"Distractor\d+: (.+?)(?= Distractor|$)", generated_text)
    return question, answer, distractors


def generate_QA(input_text):
    """Generate a multiple-choice item from a support passage with the QA model.

    Args:
        input_text: Support passage (typically the output of generate_support).

    Returns:
        A tuple ``(question, answer, distractors)`` where question and answer
        are strings ("" when not found in the model output) and distractors is
        a list of strings (possibly empty, possibly not exactly three).
    """
    # Truncate to 1024 tokens, the same limit generate_support uses: longer
    # inputs exceed BART-large's position embeddings and make generate() fail.
    inputs = tokenizer_large(
        input_text, return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    # Pass input_ids together with the attention mask.
    outputs = model_QA_large.generate(**inputs, max_length=1024)
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
    return _parse_qa_output(generated_text)
def generate_support(url, topic, direct_wiki=False, not_print=False):
    """Generate a "support" passage about *topic* with the WEB model.

    The source text is either scraped from *url* (its ``<p>`` elements) or,
    when *direct_wiki* is true, read through the Wikipedia API using the last
    path segment of *url* as the article title. Paragraphs mentioning *topic*
    are kept and passed to ``model_WEB_large`` together with the topic.

    Args:
        url: Page URL (for Wikipedia e.g. ``https://en.wikipedia.org/wiki/Alan_Turing``).
        topic: Topic phrase, used both to filter paragraphs and as model input.
        direct_wiki: Read the page through the Wikipedia API instead of scraping it.
        not_print: Suppress progress output.

    Returns:
        The generated text from its "Support:" marker onward, with surrounding
        whitespace stripped, or the whole generated text when no marker is found.
    """
    verbose = not not_print

    # --- Obtain and filter the source text ---
    if direct_wiki:
        from urllib.parse import unquote

        # rstrip('/') so a trailing slash does not yield an empty title;
        # unquote turns percent-encoded titles (e.g. "C%2B%2B") back into text.
        display_title = unquote(url.rstrip("/").split("/")[-1])
        if verbose:
            print(display_title)
        source_text = wiki_wiki.page(display_title).text
    else:
        source_text = extract_text_from_url(url)
    filtered_text = filter_paragraphs_by_topic(source_text, topic)
    if filtered_text is None:
        # Fetching failed; do not feed the literal word "None" to the model.
        filtered_text = ""

    # --- Build the model input ---
    # The labels end with a space and the f-string adds another, giving
    # "Answer:  <topic> Text:  <text>". Kept verbatim, presumably because it
    # matches the fine-tuning input format (TODO confirm).
    text_label = "Text: "
    answer_label = "Answer: "
    merged_column_input = f"{answer_label} {topic} {text_label} {filtered_text}"

    # --- Generate ---
    inputs = tokenizer_large(
        merged_column_input, return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    outputs = model_WEB_large.generate(**inputs, max_length=1024)
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
    if verbose:
        print(f"Output of url to support generator: {generated_text}.\n")

    # --- Extract the support passage ---
    matcher = re.search(r'Support:.*$', generated_text)
    if matcher:
        support = matcher.group(0).strip()
        if verbose:
            print(f"Found support in output: {support}.\n")
        return support
    if verbose:
        print("No support found in the output string. The next model will be fed with all generated text.\n This might cause strange results.")
    return generated_text
def extract_text_from_url(url, timeout=10):
    """Fetch *url* and return the text of all its ``<p>`` elements.

    Args:
        url: Address of the page to download.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The paragraph texts joined with single spaces, or None (after printing
        a message) when the request fails or the status code is not 200.
    """
    try:
        # Without a timeout, requests can block indefinitely on an unresponsive host.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as e:
        print("Error occurred while fetching URL:", str(e), flush=True)
        return None
    if response.status_code != 200:
        print("Failed to fetch URL:", url, flush=True)
        return None
    soup = BeautifulSoup(response.content, 'html.parser')
    return ' '.join(p.get_text() for p in soup.find_all('p'))
def filter_paragraphs_by_topic(text, topic):
    """Keep the newline-separated paragraphs of *text* that mention *topic*.

    Matching is case-insensitive substring matching. When no paragraph
    contains the whole topic, fall back to paragraphs containing any single
    whitespace-separated word of it. The fallback groups results by word in
    topic order and skips paragraphs already kept.

    Args:
        text: Full page text, or None when nothing could be fetched.
        topic: Topic phrase to look for.

    Returns:
        The kept paragraphs joined with single spaces ("" if none match), or
        None (after printing a message) when *text* or *topic* is None.
    """
    if text is None or topic is None:
        print("No text was fetched for this topic, as no page was found, so we cannot filter it.", flush=True)
        return None

    paragraphs = text.split('\n')
    lowered = [p.lower() for p in paragraphs]
    needle = topic.lower()

    relevant = [p for p, low in zip(paragraphs, lowered) if needle in low]
    if not relevant:
        # The full phrase may be absent while its words are present. Use
        # split() rather than split(' '): an empty "word" (from a trailing or
        # doubled space) would match every paragraph.
        seen = set()
        for word in needle.split():
            for p, low in zip(paragraphs, lowered):
                if word in low and p not in seen:
                    seen.add(p)
                    relevant.append(p)
    return ' '.join(relevant)
def generate_QA_from_url(url, topic, direct_wiki=False, not_print_support=True):
    """Run the full pipeline: page -> support passage -> (question, answer, distractors).

    *direct_wiki* and *not_print_support* are forwarded to generate_support
    (the latter as its ``not_print`` flag).
    """
    support_text = generate_support(
        url,
        topic,
        direct_wiki=direct_wiki,
        not_print=not_print_support,
    )
    return generate_QA(support_text)
def prompt(url, topic):
    """Generate one four-option multiple-choice item for *topic* from *url*.

    Returns:
        ``(question, answer, choice1, choice2, choice3, choice4)``. The
        choices are the answer plus up to three distinct distractors, in
        random order. If fewer than three usable distractors were generated,
        the remaining slots are filled with "".
    """
    question, answer, distractors = generate_QA_from_url(url, topic)

    # Drop copies of the answer and repeated distractors, so no two buttons
    # show the same text (and the answer is the only "correct" one).
    unique = []
    for candidate in distractors:
        if candidate != answer and candidate not in unique:
            unique.append(candidate)

    # Always exactly three distractors. With more, the shuffled answer could
    # fall outside the four shown choices; with fewer, fixed indexing would
    # raise IndexError.
    choices = unique[:3]
    choices += [""] * (3 - len(choices))
    choices.append(answer)
    random.shuffle(choices)
    return (question, answer, *choices)
with gr.Blocks() as demo:
    # Session state: the correct answer and the text behind each choice button.
    answer = gr.State('')
    d1 = gr.State('')
    d2 = gr.State('')
    d3 = gr.State('')
    d4 = gr.State('')

    url = gr.Text(label="URL")
    topic = gr.Text(label="Topic")
    submit = gr.Button("Submit")
    question = gr.Textbox(label="question")
    # The choice buttons and the result box stay hidden until they are needed.
    choice1 = gr.Button(value="choice1", interactive=True, visible=False)
    choice2 = gr.Button(value="choice2", interactive=True, visible=False)
    choice3 = gr.Button(value="choice3", interactive=True, visible=False)
    choice4 = gr.Button(value="choice4", interactive=True, visible=False)
    res = gr.Text(value="correct", label="Results", interactive=True, visible=False)

    def check(correct_answer, chosen):
        """Show 'correct' or 'wrong' depending on whether *chosen* is the answer."""
        verdict = 'correct' if chosen == correct_answer else 'wrong'
        return gr.Text(value=verdict, label="result", interactive=True, visible=True)

    def on_submit(url_text, topic_text):
        """Generate a new question, reveal the four choices and hide the old verdict.

        Returns values for: question, answer state, the four buttons, the four
        choice-text states, and the result box (hidden again so the previous
        question's verdict is not shown next to the new one).
        """
        question_text, correct, *choice_texts = prompt(url_text, topic_text)
        buttons = [
            gr.Button(value=text, interactive=True, visible=True)
            for text in choice_texts
        ]
        return (question_text, correct, *buttons, *choice_texts, gr.update(visible=False))

    gr.on(
        triggers=[submit.click],
        fn=on_submit,
        inputs=[url, topic],
        outputs=[question, answer, choice1, choice2, choice3, choice4, d1, d2, d3, d4, res],
    )
    # Each button compares its own stored text with the correct answer.
    choice1.click(check, inputs=[answer, d1], outputs=res)
    choice2.click(check, inputs=[answer, d2], outputs=res)
    choice3.click(check, inputs=[answer, d3], outputs=res)
    choice4.click(check, inputs=[answer, d4], outputs=res)

demo.launch(share=True)