import requests
from bs4 import BeautifulSoup
import sys
from transformers import BartForConditionalGeneration, BartTokenizer
import re
import torch
import wikipediaapi
import gradio as gr
import random

# User agent string identifying this client to the Wikipedia API
wiki_wiki = wikipediaapi.Wikipedia('MCQ Generation (r.j.a.lemein@student.rug.nl)', 'en')

# Pick a device and load the fine-tuned BART models from the Hugging Face Hub:
# WEB_large generates a support sentence from web text, CL_large generates the MCQ.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_WEB_large = BartForConditionalGeneration.from_pretrained('rizkiduwinanto/WEB_large').to(device)
model_name_large = "facebook/bart-large"
model_QA_large = BartForConditionalGeneration.from_pretrained("b-b-brouwer/CL_large").to(device)
tokenizer_large = BartTokenizer.from_pretrained(model_name_large)

def generate_QA(input_text):
    # Tokenize the input text, truncating to the model's 1024-token input limit
    inputs = tokenizer_large(input_text, return_tensors="pt", max_length=1024, truncation=True).to(device)
    input_ids = inputs["input_ids"]

    # Generate outputs
    outputs = model_QA_large.generate(input_ids=input_ids, max_length=1024)

    # Decode the generated outputs
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)

    question_pattern = r"Question: (.+?\?)"
    answer_pattern = r"Answer: (.+?)(?= Distractor\d+: |$)"

    question = ""
    answer = ""
    
    question_match = re.search(question_pattern, generated_text)
    answer_match = re.search(answer_pattern, generated_text)

    if question_match:
        question = question_match.group(1)
    if answer_match:
        answer = answer_match.group(1)

    distractor_pattern = r"Distractor\d+: (.+?)(?= Distractor|$)"
    distractors = re.findall(distractor_pattern, generated_text)
    
    return question, answer, distractors
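
# Illustrative sketch of the raw model output that the regexes above expect
# (the wording below is hypothetical, not taken from a real run):
#   Question: What organelle produces most of the cell's ATP? Answer: mitochondria
#   Distractor1: nucleus Distractor2: ribosome Distractor3: chloroplast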

def generate_support(url, topic, direct_wiki=False, not_print=False):
    # Obtain the source text: either scrape the URL directly, or fetch the
    # corresponding Wikipedia page through the API
    if not direct_wiki:
        text = extract_text_from_url(url)
        filtered_text = filter_paragraphs_by_topic(text, topic)
    else:
        display_title = url.split("/")[-1]
        if not not_print:
            print(display_title)
        page = wiki_wiki.page(display_title)
        plain_text_content = page.text
        filtered_text = filter_paragraphs_by_topic(plain_text_content, topic)
    
    text_label = "Text: "
    answer_label = "Answer: "

    # Build the model input in the training format: "Answer: <topic> Text: <filtered text>"
    merged_column_input = f"{answer_label} {topic} {text_label} {filtered_text}"

    # Tokenize the filtered text, truncating to the model's 1024-token input limit.
    # Note that the large model takes the topic as part of its input as well!
    inputs = tokenizer_large(merged_column_input, return_tensors="pt", max_length=1024, truncation=True).to(device)
    input_ids = inputs["input_ids"]
    
    # Generate outputs
    outputs = model_WEB_large.generate(input_ids=input_ids, max_length=1024)

    # Decode the generated outputs
    generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)

    # Print the generated text
    if not not_print:
        print(f"Output of URL-to-support generator: {generated_text}.\n")

    # Extract the support sentence from the generated text
    pattern = r'Support:.*$'

    matcher = re.search(pattern, generated_text)

    if matcher:
        support = matcher.group(0).strip()  # strip() returns a new string; keep the result
        if not not_print:
            print(f"Found support in output: {support}.\n")
        return support
    else:
        if not not_print:
            print("No support found in the output string. The next model will be fed all of the generated text.\nThis might cause strange results.")
        return generated_text
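
# Given the extraction pattern above, the returned support is expected to look
# like (hypothetical example):
#   "Support: Mitochondria generate most of the cell's supply of ATP."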


def extract_text_from_url(url):
    try:
        response = requests.get(url)
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, 'html.parser')
            # Extract text from all paragraph tags
            paragraphs = [p.get_text() for p in soup.find_all('p')]
            return ' '.join(paragraphs)  # Concatenate paragraphs into a single text
        else:
            print("Failed to fetch URL:", url)
            sys.stdout.flush()            
            return None
    except Exception as e:
        print("Error occurred while fetching URL:", str(e))
        sys.stdout.flush()
        return None


def filter_paragraphs_by_topic(text, topic):
    relevant_paragraphs = []
    try:
        # Split the text into paragraphs based on newline characters
        paragraphs = text.split('\n')

        # Simple heuristic: keep a paragraph if it mentions the topic verbatim
        for paragraph in paragraphs:
            if topic.lower() in paragraph.lower():
                relevant_paragraphs.append(paragraph)

        # If the full topic never occurs, fall back to matching its individual words
        if len(relevant_paragraphs) == 0:
            topics = topic.split(' ')
            for split_topic in topics:
                for paragraph in paragraphs:
                    if split_topic.lower() in paragraph.lower() and paragraph not in relevant_paragraphs:
                        relevant_paragraphs.append(paragraph)

    except AttributeError:
        # text is None when no page could be fetched, so there is nothing to filter
        print("No text was fetched for this topic, as no page was found, so we cannot filter it.")
        sys.stdout.flush()
        return None

    return ' '.join(relevant_paragraphs)
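
# Illustrative example (hypothetical input text):
#   filter_paragraphs_by_topic("Solar panels convert light.\nWind power differs.",
#                              "solar panels")
#   -> "Solar panels convert light."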


def generate_QA_from_url(url, topic, direct_wiki=False, not_print_support=True):
    # Full pipeline: URL -> support sentence -> (question, answer, distractors)
    return generate_QA(generate_support(url, topic, direct_wiki=direct_wiki, not_print=not_print_support))
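
# End-to-end example (illustrative; the URL and topic below are placeholders):
#   q, a, ds = generate_QA_from_url(
#       "https://en.wikipedia.org/wiki/Mitochondrion", "mitochondria", direct_wiki=True)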

def prompt(url, topic):
    question, answer, distractors = generate_QA_from_url(url, topic)
    # Mix the correct answer in with the distractors and shuffle the choices.
    # Pad first in case the model emitted fewer than three distractors, so the
    # indexing below cannot raise an IndexError.
    distractors.append(answer)
    while len(distractors) < 4:
        distractors.append("")
    random.shuffle(distractors)
    return question, answer, distractors[0], distractors[1], distractors[2], distractors[3]

with gr.Blocks() as demo:
    # Hidden state: the correct answer and the text shown on each choice button
    answer = gr.State('')
    d1 = gr.State('')
    d2 = gr.State('')
    d3 = gr.State('')
    d4 = gr.State('')
    
    url = gr.Text(label="URL")
    topic = gr.Text(label="Topic")
    submit = gr.Button("Submit")
    
    question = gr.Textbox(label="question")

    # Choice buttons stay hidden until a question has been generated
    choice1 = gr.Button(value="choice1", interactive=True, visible=False)
    choice2 = gr.Button(value="choice2", interactive=True, visible=False)
    choice3 = gr.Button(value="choice3", interactive=True, visible=False)
    choice4 = gr.Button(value="choice4", interactive=True, visible=False)

    # Result box, hidden until a choice is clicked
    res = gr.Text(value="", label="Results", interactive=True, visible=False)

    def check(correct_answer, choice):
        # Reveal the result box, reporting whether the clicked choice is the answer
        if choice == correct_answer:
            return gr.Text(value='correct', label="result", interactive=True, visible=True)
        else:
            return gr.Text(value='wrong', label="result", interactive=True, visible=True)

    def on_submit(url, topic):
        question_text, answer, choice1_text, choice2_text, choice3_text, choice4_text = prompt(url, topic)

        # Returning new Button components updates the existing ones in place,
        # revealing the shuffled choices
        choice1 = gr.Button(value=choice1_text, interactive=True, visible=True)
        choice2 = gr.Button(value=choice2_text, interactive=True, visible=True)
        choice3 = gr.Button(value=choice3_text, interactive=True, visible=True)
        choice4 = gr.Button(value=choice4_text, interactive=True, visible=True)

        return question_text, answer, choice1, choice2, choice3, choice4, choice1_text, choice2_text, choice3_text, choice4_text
    
    gr.on(
        triggers=[submit.click],
        fn=on_submit,
        inputs=[url, topic],
        outputs=[question, answer, choice1, choice2, choice3, choice4, d1, d2, d3, d4]
    )

    choice1.click(check, inputs=[answer, d1], outputs=res)
    choice2.click(check, inputs=[answer, d2], outputs=res)
    choice3.click(check, inputs=[answer, d3], outputs=res)
    choice4.click(check, inputs=[answer, d4], outputs=res)
    
# share=True additionally exposes the app through a temporary public gradio.live link
demo.launch(share=True)