challenge-NLP / app.py
rizkiduwinanto's picture
add app.py
e20c0ed
raw
history blame contribute delete
No virus
7.93 kB
import requests
from bs4 import BeautifulSoup
import sys
from transformers import BartForConditionalGeneration, BartTokenizer
import re
import torch
import wikipediaapi
import gradio as gr
import random
# User key for the Wikipedia API
# (user-agent string identifying this app/maintainer to Wikimedia, English wiki)
wiki_wiki = wikipediaapi.Wikipedia('MCQ Generation (r.j.a.lemein@student.rug.nl)', 'en')
# Load the model from the saved path
# Prefer GPU when available; all models and tokenized inputs are moved to this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fine-tuned BART used by generate_support(): turns "Answer: <topic> Text: <text>"
# into a "Support: ..." passage.
model_WEB_large = BartForConditionalGeneration.from_pretrained('rizkiduwinanto/WEB_large').to(device)
model_name_large = "facebook/bart-large"
# Fine-tuned BART used by generate_QA(): turns a support passage into
# "Question: ... Answer: ... Distractor1..3: ..." text.
model_QA_large = BartForConditionalGeneration.from_pretrained("b-b-brouwer/CL_large").to(device)
# Both fine-tuned checkpoints share the base facebook/bart-large tokenizer.
tokenizer_large = BartTokenizer.from_pretrained(model_name_large)
def generate_QA(input_text):
    """Generate a multiple-choice question from a support text.

    Feeds *input_text* to the QA model and parses its flat output, which is
    expected to look like:
        "Question: ...? Answer: ... Distractor1: ... Distractor2: ... Distractor3: ..."

    Args:
        input_text: Support passage (typically produced by generate_support).

    Returns:
        (question, answer, distractors): question/answer are strings (empty
        when the corresponding label is absent from the model output) and
        distractors is the list of distractor strings found (possibly fewer
        than three).
    """
    # Tokenize with truncation at the model's 1024-token limit so an
    # over-long support cannot crash generation (matches generate_support).
    inputs = tokenizer_large(
        input_text, return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    input_ids = inputs["input_ids"]
    # Generate and decode the model output.
    outputs = model_QA_large.generate(input_ids=input_ids, max_length=1024)
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
    # Pull the labelled fields back out of the flat generated string.
    question_pattern = r"Question: (.+?\?)"
    answer_pattern = r"Answer: (.+?)(?= Distractor\d+: |$)"
    question = ""
    answer = ""
    question_match = re.search(question_pattern, generated_text)
    answer_match = re.search(answer_pattern, generated_text)
    if question_match:
        question = question_match.group(1)
    if answer_match:
        answer = answer_match.group(1)
    # One numbered pattern captures every "DistractorN:" field at once
    # (the original per-number dis1/dis2/dis3 patterns were never used).
    distractor_pattern = r"Distractor\d+: (.+?)(?= Distractor|$)"
    distractors = re.findall(distractor_pattern, generated_text)
    return question, answer, distractors
def generate_support(url, topic, direct_wiki=False, not_print=False):
    """Build a "Support: ..." passage for *topic* from a web page or wiki article.

    Args:
        url: Page URL. When direct_wiki is True only the last path segment
            (the article title) is used, via the Wikipedia API.
        topic: Topic string; used both to filter source paragraphs and as
            the "Answer:" conditioning fed to the support model.
        direct_wiki: If True, fetch through wikipediaapi instead of scraping
            the URL with requests/BeautifulSoup.
        not_print: If True, suppress debug printing.

    Returns:
        The extracted "Support: ..." text (stripped), or the full generated
        text when no "Support:" marker appears in the model output.
    """
    # Obtain the source text, either by scraping the URL or via the wiki API.
    if not direct_wiki:
        text = extract_text_from_url(url)
        filtered_text = filter_paragraphs_by_topic(text, topic)
    else:
        display_title = url.split("/")[-1]
        if not not_print:
            print(display_title)
        page = wiki_wiki.page(display_title)
        plain_text_content = page.text
        filtered_text = filter_paragraphs_by_topic(plain_text_content, topic)
    text_label = "Text: "
    answer_label = "Answer: "
    # Now these texts are inputs, and supports are outputs.
    # The large model expects "Answer: <topic> Text: <filtered text>".
    merged_column_input = f"{answer_label} {topic} {text_label} {filtered_text}"
    # Tokenize the filtered text, truncated to the model's 1024-token limit.
    inputs = tokenizer_large(
        merged_column_input, return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    input_ids = inputs["input_ids"]
    # Generate and decode the model output.
    outputs = model_WEB_large.generate(input_ids=input_ids, max_length=1024)
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
    if not not_print:
        print(f"Output of url to support generator: {generated_text}.\n")
    # Extract everything from "Support:" to the end of the string.
    pattern = r'Support:.*$'
    matcher = re.search(pattern, generated_text)
    if matcher:
        # BUG FIX: str.strip() returns a new string; the original called
        # support.strip() and discarded the result, returning it unstripped.
        support = matcher.group(0).strip()
        if not not_print:
            print(f"Found support in output: {support}.\n")
        return support
    if not not_print:
        print("No support found in the output string. The next model will be fed with all generated text.\n This might cause strange results.")
    return generated_text
def extract_text_from_url(url):
    """Fetch *url* and return the concatenated text of its <p> tags.

    Best-effort scraper: returns None (after printing a diagnostic) on any
    network error or non-200 response rather than crashing the UI.
    """
    try:
        # Bound the request so a hung server cannot stall the app forever;
        # a timeout raises requests.Timeout, handled by the except below.
        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, 'html.parser')
            # Extract text from all paragraph tags
            paragraphs = [p.get_text() for p in soup.find_all('p')]
            return ' '.join(paragraphs)  # Concatenate paragraphs into a single text
        print("Failed to fetch URL:", url)
        sys.stdout.flush()
        return None
    except Exception as e:
        print("Error occurred while fetching URL:", str(e))
        sys.stdout.flush()
        return None
def filter_paragraphs_by_topic(text, topic):
    """Keep only the newline-separated paragraphs of *text* mentioning *topic*.

    Matching is case-insensitive. If no paragraph contains the whole topic
    phrase, fall back to paragraphs containing any single word of the topic.

    Args:
        text: Source text with paragraphs separated by newlines, or None
            when the upstream fetch failed.
        topic: Topic phrase to search for.

    Returns:
        The matching paragraphs joined by single spaces (empty string if
        nothing matches), or None when *text* has no usable text.
    """
    relevant_paragraphs = []
    try:
        # Split the text into paragraphs based on newline characters.
        paragraphs = text.split('\n')
    except AttributeError:
        # text is None (no page was fetched upstream). The original used a
        # bare `except:` here, which could also mask genuine bugs.
        print("No text was fetched for this topic, as no page was found, so we cannot filter it.")
        sys.stdout.flush()
        return None
    # First pass: keep any paragraph containing the full topic phrase.
    for paragraph in paragraphs:
        if topic.lower() in paragraph.lower():
            relevant_paragraphs.append(paragraph)
    # Fallback: the full phrase may never appear verbatim, so match any
    # individual word of the topic instead.
    if not relevant_paragraphs:
        for split_topic in topic.split(' '):
            for paragraph in paragraphs:
                if split_topic.lower() in paragraph.lower() and paragraph not in relevant_paragraphs:
                    relevant_paragraphs.append(paragraph)
    return ' '.join(relevant_paragraphs)
def generate_QA_from_url(url, topic, direct_wiki=False, not_print_support=True):
    """Fetch/derive a support text for *topic* from *url*, then run QA generation on it."""
    support = generate_support(url, topic, direct_wiki=direct_wiki, not_print=not_print_support)
    return generate_QA(support)
def prompt(url, topic):
    """Produce a shuffled multiple-choice prompt for the Gradio UI.

    Args:
        url: Source page URL.
        topic: Topic to build the question about.

    Returns:
        (question, answer, c1, c2, c3, c4): the question text, the correct
        answer, and four shuffled choices (the answer plus three distractors).
    """
    question, answer, distractors = generate_QA_from_url(url, topic)
    # The model may emit fewer (or more) than three distractors. Cap at
    # three and pad with empty strings so the four choice buttons always get
    # a value (the original indexed distractors[3] and could raise
    # IndexError, or drop the answer entirely when >3 distractors appeared).
    choices = list(distractors[:3])
    while len(choices) < 3:
        choices.append("")
    choices.append(answer)
    random.shuffle(choices)
    return question, answer, choices[0], choices[1], choices[2], choices[3]
with gr.Blocks() as demo:
    # Hidden State slots: the correct answer plus the four button labels, so
    # the click handlers can compare what was clicked against the answer.
    answer = gr.State('')
    d1 = gr.State('')
    d2 = gr.State('')
    d3 = gr.State('')
    d4 = gr.State('')
    # User inputs: page URL and topic to build the question from.
    url = gr.Text(label="URL")
    topic = gr.Text(label="Topic")
    submit = gr.Button("Submit")
    question = gr.Textbox(label="question")
    # Choice buttons start hidden; on_submit reveals them with real labels.
    choice1 = gr.Button(value="choice1", interactive=True, visible=False)
    choice2 = gr.Button(value="choice2", interactive=True, visible=False)
    choice3 = gr.Button(value="choice3", interactive=True, visible=False)
    choice4 = gr.Button(value="choice4", interactive=True, visible=False)
    # Result box, also hidden until the first choice is clicked.
    res = gr.Text(value="correct", label="Results", interactive=True, visible=True)
    def check(correct_answer, answer):
        # NOTE: despite the name, the second parameter is the clicked
        # button's label (one of the d1..d4 States), not the stored answer.
        # Returns a visible Text component showing "correct" or "wrong".
        if answer == correct_answer:
            return gr.Text(value='correct', label="result", interactive=True, visible=True)
        else:
            return gr.Text(value='wrong', label="result", interactive=True, visible=True)
    def on_submit(url, topic):
        # Generate a question for the URL/topic, then return updated Button
        # components (now visible, labelled with the choices) plus the raw
        # choice texts so they land in the d1..d4 States for check().
        question_text, answer, choice1_text, choice2_text, choice3_text, choice4_text = prompt(url, topic)
        choice1 = gr.Button(value=choice1_text, interactive=True, visible=True)
        choice2 = gr.Button(value=choice2_text, interactive=True, visible=True)
        choice3 = gr.Button(value=choice3_text, interactive=True, visible=True)
        choice4 = gr.Button(value=choice4_text, interactive=True, visible=True)
        return question_text, answer, choice1, choice2, choice3, choice4, choice1_text, choice2_text, choice3_text, choice4_text
    # Wire the submit button: outputs fill the question box, the answer
    # State, the four buttons, and the four label States, in that order.
    gr.on(
        triggers=[submit.click],
        fn=on_submit,
        inputs=[url, topic],
        outputs=[question, answer, choice1, choice2, choice3, choice4, d1, d2, d3, d4]
    )
    # Each button checks its own stored label against the answer State.
    choice1.click(check, inputs=[answer, d1], outputs=res)
    choice2.click(check, inputs=[answer, d2], outputs=res)
    choice3.click(check, inputs=[answer, d3], outputs=res)
    choice4.click(check, inputs=[answer, d4], outputs=res)
# share=True creates a public Gradio tunnel URL in addition to localhost.
demo.launch(share=True)