"""Gradio demo that turns a web page and a topic into a multiple-choice question.

Pipeline: fetch text (plain HTTP or the Wikipedia API), keep the paragraphs
that mention the topic, let one BART model write a "support" passage for the
topic, then let a second BART model turn that support into a question, an
answer and distractors, which are shown as four clickable choices.
"""

import random
import re
import sys

import gradio as gr
import requests
import torch
import wikipediaapi
from bs4 import BeautifulSoup
from transformers import BartForConditionalGeneration, BartTokenizer

# User key for the Wikipedia API
wiki_wiki = wikipediaapi.Wikipedia('MCQ Generation (r.j.a.lemein@student.rug.nl)', 'en')

# Load the model from the saved path
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_WEB_large = BartForConditionalGeneration.from_pretrained('rizkiduwinanto/WEB_large').to(device)
model_name_large = "facebook/bart-large"
model_QA_large = BartForConditionalGeneration.from_pretrained("b-b-brouwer/CL_large").to(device)
tokenizer_large = BartTokenizer.from_pretrained(model_name_large)

# Token budget used for both tokenization and generation.
MAX_TOKENS = 1024
# Upper bound (seconds) on a page fetch so a dead host cannot hang the UI.
REQUEST_TIMEOUT_SECONDS = 30
# Number of wrong options shown next to the correct answer.
NUM_DISTRACTORS = 3

QUESTION_PATTERN = r"Question: (.+?\?)"
ANSWER_PATTERN = r"Answer: (.+?)(?= Distractor\d+: |$)"
DISTRACTOR_PATTERN = r"Distractor\d+: (.+?)(?= Distractor|$)"
SUPPORT_PATTERN = r'Support:.*$'


def generate_QA(input_text):
    """Generate a question, its answer and distractors from a support text.

    Args:
        input_text: Text fed to the question-generation model.

    Returns:
        A ``(question, answer, distractors)`` tuple. ``question`` and
        ``answer`` are empty strings when the corresponding label is not found
        in the model output; ``distractors`` is a (possibly empty) list of
        strings.
    """
    # Truncate like generate_support does, so an over-long support cannot
    # exceed the token budget used everywhere else in this file.
    inputs = tokenizer_large(
        input_text, return_tensors="pt", max_length=MAX_TOKENS, truncation=True
    ).to(device)
    outputs = model_QA_large.generate(input_ids=inputs["input_ids"], max_length=MAX_TOKENS)
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)

    question_match = re.search(QUESTION_PATTERN, generated_text)
    answer_match = re.search(ANSWER_PATTERN, generated_text)
    question = question_match.group(1) if question_match else ""
    answer = answer_match.group(1) if answer_match else ""
    distractors = re.findall(DISTRACTOR_PATTERN, generated_text)
    return question, answer, distractors


def generate_support(url, topic, direct_wiki=False, not_print=False):
    """Generate a support passage about ``topic`` from the page at ``url``.

    Args:
        url: Page address. With ``direct_wiki`` the last path segment is used
            as the Wikipedia page title.
        topic: Topic used to filter the page text and to prime the model.
        direct_wiki: Fetch the text through the Wikipedia API instead of
            scraping ``url``.
        not_print: Suppress the progress messages printed to stdout.

    Returns:
        The ``Support: ...`` part of the model output with surrounding
        whitespace removed, or the full model output when no such part exists.
    """
    verbose = not not_print

    # Obtain input
    if direct_wiki:
        display_title = url.split("/")[-1]
        if verbose:
            print(display_title)
        page = wiki_wiki.page(display_title)
        filtered_text = filter_paragraphs_by_topic(page.text, topic)
    else:
        text = extract_text_from_url(url)
        filtered_text = filter_paragraphs_by_topic(text, topic)

    # A failed fetch yields None; feed an empty text rather than the literal
    # string "None".
    if filtered_text is None:
        filtered_text = ""

    text_label = "Text: "
    answer_label = "Answer: "
    # Now these texts are inputs, and supports are outputs
    merged_column_input = f"{answer_label} {topic} {text_label} {filtered_text}"

    # Tokenize the filtered text. Note that large has the topic as an input as well!
    inputs = tokenizer_large(
        merged_column_input, return_tensors="pt", max_length=MAX_TOKENS, truncation=True
    ).to(device)
    outputs = model_WEB_large.generate(input_ids=inputs["input_ids"], max_length=MAX_TOKENS)
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)

    if verbose:
        print(f"Output of url to support generator: {generated_text}.\n")

    # Extract the support
    matcher = re.search(SUPPORT_PATTERN, generated_text)
    if matcher:
        # str.strip() returns a new string, so the result must be kept.
        support = matcher.group(0).strip()
        if verbose:
            print(f"Found support in output: {support}.\n")
        return support

    if verbose:
        print(
            "No support found in the output string. "
            "The next model will be fed with all generated text.\n"
            " This might cause strange results."
        )
    return generated_text


def extract_text_from_url(url):
    """Fetch ``url`` and return the text of all its ``<p>`` elements.

    Args:
        url: Address of the page to download.

    Returns:
        The paragraph texts joined by newlines, or ``None`` when the request
        fails or does not answer with status 200.
    """
    try:
        response = requests.get(url, timeout=REQUEST_TIMEOUT_SECONDS)
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, 'html.parser')
            # Extract text from all paragraph tags
            paragraphs = [p.get_text() for p in soup.find_all('p')]
            # Join with newlines: filter_paragraphs_by_topic splits on '\n',
            # so a space separator would turn the page into one paragraph.
            return '\n'.join(paragraphs)
        print("Failed to fetch URL:", url)
        sys.stdout.flush()
        return None
    except Exception as e:  # best-effort fetch: report and fall back to None
        print("Error occurred while fetching URL:", str(e))
        sys.stdout.flush()
        return None


def filter_paragraphs_by_topic(text, topic):
    """Keep only the paragraphs of ``text`` that mention ``topic``.

    Matching is case-insensitive. When no paragraph contains the whole topic,
    paragraphs containing any single word of the topic are kept instead.

    Args:
        text: Newline-separated paragraphs, or ``None`` if fetching failed.
        topic: Topic to look for.

    Returns:
        The matching paragraphs joined by spaces, or ``None`` when ``text`` or
        ``topic`` is not a string.
    """
    if not isinstance(text, str) or not isinstance(topic, str):
        print("No text was fetched for this topic, as no page was found, so we cannot filter it.")
        sys.stdout.flush()
        return None

    # Split the text into paragraphs based on newline characters
    paragraphs = text.split('\n')
    needle = topic.strip().lower()

    # Simple, simply keep the paragraph if it contains the topic
    relevant_paragraphs = [p for p in paragraphs if needle in p.lower()]

    # However, simple doesn't always work, so fall back to the separate words.
    if not relevant_paragraphs:
        # split() without arguments drops empty words, which would otherwise
        # match every paragraph.
        for word in needle.split():
            for paragraph in paragraphs:
                if word in paragraph.lower() and paragraph not in relevant_paragraphs:
                    relevant_paragraphs.append(paragraph)

    return ' '.join(relevant_paragraphs)


def generate_QA_from_url(url, topic, direct_wiki=False, not_print_support=True):
    """Run the full pipeline: page -> support passage -> question and answers.

    Args:
        url: Page address passed to ``generate_support``.
        topic: Topic passed to ``generate_support``.
        direct_wiki: Passed through to ``generate_support``.
        not_print_support: Passed to ``generate_support`` as ``not_print``.

    Returns:
        The ``(question, answer, distractors)`` tuple of ``generate_QA``.
    """
    support = generate_support(url, topic, direct_wiki=direct_wiki, not_print=not_print_support)
    return generate_QA(support)


def prompt(url, topic):
    """Build one multiple-choice question with four shuffled choices.

    Args:
        url: Page address to generate the question from.
        topic: Topic the question should be about.

    Returns:
        ``(question, answer, choice1, choice2, choice3, choice4)`` where the
        four choices are the answer and three distractors in random order.
        Missing distractors are replaced by empty strings.
    """
    question, answer, distractors = generate_QA_from_url(url, topic)

    # Use exactly three distractors: fewer would break the indexing below,
    # more could shuffle the correct answer out of the four shown choices.
    choices = list(distractors[:NUM_DISTRACTORS])
    choices.extend([""] * (NUM_DISTRACTORS - len(choices)))
    choices.append(answer)
    random.shuffle(choices)
    return question, answer, choices[0], choices[1], choices[2], choices[3]


with gr.Blocks() as demo:
    # Per-session state: the correct answer and the text behind each button.
    answer = gr.State('')
    d1 = gr.State('')
    d2 = gr.State('')
    d3 = gr.State('')
    d4 = gr.State('')

    url = gr.Text(label="URL")
    topic = gr.Text(label="Topic")
    submit = gr.Button("Submit")
    question = gr.Textbox(label="question")

    # The choice buttons and the result box stay hidden until they are needed.
    choice1 = gr.Button(value="choice1", interactive=True, visible=False)
    choice2 = gr.Button(value="choice2", interactive=True, visible=False)
    choice3 = gr.Button(value="choice3", interactive=True, visible=False)
    choice4 = gr.Button(value="choice4", interactive=True, visible=False)
    res = gr.Text(value="correct", label="Results", interactive=True, visible=False)

    def check(correct_answer, answer):
        """Return a visible result box saying whether ``answer`` is correct."""
        verdict = 'correct' if answer == correct_answer else 'wrong'
        return gr.Text(value=verdict, label="result", interactive=True, visible=True)

    def on_submit(url, topic):
        """Generate a question and return updates for the UI and its state."""
        question_text, answer, choice1_text, choice2_text, choice3_text, choice4_text = prompt(url, topic)
        choice1 = gr.Button(value=choice1_text, interactive=True, visible=True)
        choice2 = gr.Button(value=choice2_text, interactive=True, visible=True)
        choice3 = gr.Button(value=choice3_text, interactive=True, visible=True)
        choice4 = gr.Button(value=choice4_text, interactive=True, visible=True)
        return (
            question_text, answer,
            choice1, choice2, choice3, choice4,
            choice1_text, choice2_text, choice3_text, choice4_text,
        )

    gr.on(
        triggers=[submit.click],
        fn=on_submit,
        inputs=[url, topic],
        outputs=[question, answer, choice1, choice2, choice3, choice4, d1, d2, d3, d4],
    )

    # Each button compares its own stored text with the stored correct answer.
    choice1.click(check, inputs=[answer, d1], outputs=res)
    choice2.click(check, inputs=[answer, d2], outputs=res)
    choice3.click(check, inputs=[answer, d3], outputs=res)
    choice4.click(check, inputs=[answer, d4], outputs=res)

demo.launch(share=True)