|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
|
|
|
|
from transformers import AutoTokenizer |
|
# Tokenizers for the two fine-tuned seq2seq checkpoints: the first feeds
# generate_questions (open-ended questions), the second feeds the MCQ
# distractor generator. Loaded in the order listed.
tokenizer1, tokenizer2 = (
    AutoTokenizer.from_pretrained(checkpoint)
    for checkpoint in ("Reham721/Subjective_QG", "Reham721/MCQs")
)
|
|
|
from transformers import AutoModelForSeq2SeqLM |
|
|
|
# Seq2seq models matching tokenizer1 / tokenizer2 respectively; loaded in the
# order listed.
model1, model2 = (
    AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    for checkpoint in ("Reham721/Subjective_QG", "Reham721/MCQs")
)
|
|
|
|
|
from arabert.preprocess import ArabertPreprocessor |
|
from transformers import pipeline |
|
|
|
# Text preprocessor configured for the AraELECTRA base checkpoint;
# get_sorted_questions runs each context through it before QA scoring.
prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")

# Extractive question-answering pipeline; its "score" output is what
# get_sorted_questions uses to rank candidate questions.
qa_pipe = pipeline(
    task="question-answering",
    model="wissamantoun/araelectra-base-artydiqa",
)
|
|
|
def generate_questions(model, tokenizer, input_sequence):
    """Generate three candidate questions for ``input_sequence``.

    The input is encoded to PyTorch tensors and decoded with beam search
    (3 beams, 3 returned sequences, no repeated 3-grams, at most 200 tokens).

    Args:
        model: seq2seq model exposing ``generate``.
        tokenizer: tokenizer matching ``model``.
        input_sequence: prompt text (in this file, ``answer<sep>context``).

    Returns:
        List of decoded question strings with special tokens removed, one per
        returned beam.
    """
    encoded = tokenizer.encode(input_sequence, return_tensors='pt')

    # temperature=1 is kept as in the original; without do_sample it has no
    # effect on beam search.
    generation_kwargs = dict(
        max_length=200,
        num_beams=3,
        no_repeat_ngram_size=3,
        early_stopping=True,
        temperature=1,
        num_return_sequences=3,
    )
    sequences = model.generate(input_ids=encoded, **generation_kwargs)

    return [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in sequences]
|
|
|
def get_sorted_questions(questions, context):
    """Rank candidate questions by QA-pipeline confidence against ``context``.

    The context is passed through ``prep.preprocess`` once, then each question
    is scored with ``qa_pipe``. Every question and its raw pipeline result are
    printed as they are scored.

    Args:
        questions: iterable of question strings; duplicates collapse to one
            entry (the last score wins, first position is kept).
        context: source paragraph the questions should be answerable from.

    Returns:
        Dict mapping question -> score, ordered from highest to lowest score;
        ties keep their original relative order.
    """
    normalized_context = prep.preprocess(context)

    scores = {}
    for candidate in questions:
        print(candidate)
        qa_result = qa_pipe(question=candidate, context=normalized_context)
        print(qa_result)
        scores[candidate] = qa_result["score"]

    ranked = sorted(scores, key=scores.get, reverse=True)
    return {candidate: scores[candidate] for candidate in ranked}
|
|
|
|
|
|
|
import unicodedata |
|
import arabic_reshaper |
|
from bidi.algorithm import get_display |
|
|
|
def is_arabic(text):
    """Return True if every alphabetic character in ``text`` is an Arabic letter.

    A character counts as Arabic when its Unicode name starts with "ARABIC".
    Non-alphabetic characters (digits, whitespace, punctuation, combining
    diacritics) are ignored. As a result, text containing no letters at all,
    including the empty string, is reported as Arabic.

    Args:
        text: string to classify.

    Returns:
        bool.
    """
    # Reshaping (arabic_reshaper) and bidi reordering are display-only
    # transforms. They swap Arabic letters for Arabic presentation forms (whose
    # names also start with "ARABIC") and reorder characters, so they cannot
    # change this per-character check. They are skipped to avoid wasted work.
    # The "" default keeps characters without a Unicode name from raising
    # ValueError; such characters are simply treated as non-Arabic letters.
    return all(
        unicodedata.name(char, "").startswith("ARABIC")
        for char in text
        if char.isalpha()
    )
|
|
|
import random |
|
import re |
|
# Matches any angle-bracketed tag (e.g. "<sep>") left in decoded model output.
_DISTRACTOR_TAG_RE = re.compile(r'<[^>]*>')
# Splits decoded output on tags (captured) or on the literal text "None".
_DISTRACTOR_SPLIT_RE = re.compile(r'(<[^>]*>)|(?:None)')


def _extract_distractors(decoded_output):
    """Split one decoded MCQ-model output into cleaned, Arabic-only candidates."""
    # The capturing group makes re.split also return the tags themselves (and
    # None for the "None" branch). Falsy pieces are dropped here, and tags are
    # blanked by the substitution below.
    pieces = [piece for piece in _DISTRACTOR_SPLIT_RE.split(decoded_output) if piece]
    cleaned = [_DISTRACTOR_TAG_RE.sub('', piece.strip()).strip() for piece in pieces]
    return [piece for piece in cleaned if piece and is_arabic(piece)]


def _sample_distractor_outputs(input_ids, num_sequences):
    """Sample ``num_sequences`` outputs from the MCQ model and decode them to text."""
    outputs = model2.generate(
        input_ids,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=num_sequences,
        no_repeat_ngram_size=2,
    )
    return [tokenizer2.decode(output, skip_special_tokens=True) for output in outputs]


def generate_distractors(question, answer, context, num_distractors=3, k=10, max_attempts=10):
    """Generate wrong-answer options for a multiple-choice question.

    The question, answer and context are joined with " <sep> " and sampled
    from the MCQ model (``model2`` / ``tokenizer2``) with top-k/top-p sampling.
    Each decoded output is split into candidates. A candidate is kept only if
    it is non-empty, Arabic according to ``is_arabic``, not already collected,
    and not equal to the (stripped) answer. Both the first sampling round and
    any top-up rounds apply exactly the same filtering.

    Args:
        question: question text the distractors should fit.
        answer: correct answer; never returned as a distractor.
        context: source paragraph.
        num_distractors: number of distractors wanted.
        k: maximum size of the random candidate pool the result is drawn from.
        max_attempts: maximum number of extra sampling rounds when the first
            round yields too few distinct candidates (bounds the total work).

    Returns:
        List of distinct distractor strings in random order. It contains
        ``num_distractors`` items unless the model could not produce that many
        within ``max_attempts`` extra rounds, or ``k`` is smaller. In those
        cases the list is shorter.
    """
    if num_distractors <= 0:
        return []

    input_sequence = f'{question} <sep> {answer} <sep> {context}'
    input_ids = tokenizer2.encode(input_sequence, return_tensors='pt')
    answer_text = str(answer).strip()

    unique_distractors = []
    seen = set()

    def _collect(decoded_outputs):
        # Keep first occurrences only, and never offer the correct answer.
        for decoded in decoded_outputs:
            for candidate in _extract_distractors(decoded):
                if candidate != answer_text and candidate not in seen:
                    seen.add(candidate)
                    unique_distractors.append(candidate)

    _collect(_sample_distractor_outputs(input_ids, num_distractors))

    # Top up with fresh samples, but stop after max_attempts rounds so a model
    # that keeps repeating itself cannot hang the request.
    for _ in range(max_attempts):
        missing = num_distractors - len(unique_distractors)
        if missing <= 0:
            break
        _collect(_sample_distractor_outputs(input_ids, missing))

    # Cap the candidate pool at k, then draw the final subset from it at random.
    # Sizes are clamped so that small k or too few candidates cannot raise.
    pool = random.sample(unique_distractors, min(max(k, 0), len(unique_distractors)))
    return random.sample(pool, min(num_distractors, len(pool)))
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
# Input widgets: the source paragraph, the answer/keyword to ask about, and
# the kind of question to produce.
context = gr.inputs.Textbox(
    lines=5,
    placeholder="Enter paragraph/context here...",
)
answer = gr.inputs.Textbox(
    lines=3,
    placeholder="Enter answer/keyword here...",
)
question_type = gr.inputs.Radio(
    choices=["Subjective", "MCQ"],
    label="Question type",
)

# Single text output showing the generated question (plus options for MCQs).
question = gr.outputs.Textbox(type="text", label="Question")
|
|
|
def generate_question(context, answer, question_type):
    """Gradio callback: generate a question (and MCQ options) for an answer span.

    Candidate questions come from ``model1`` using the prompt
    ``answer + "<sep>" + context``. They are ranked by ``get_sorted_questions``,
    and the top-ranked one is used.

    Args:
        context: source paragraph.
        answer: answer/keyword the question should target.
        question_type: "Subjective" for an open question; any other value
            produces a multiple-choice question.

    Returns:
        For "Subjective", the top-ranked question. Otherwise, that question
        followed by one "-"-prefixed option per line: the generated
        distractors first, then the correct answer.
    """
    article = answer + "<sep>" + context
    candidates = generate_questions(model1, tokenizer1, article)
    ranked = get_sorted_questions(candidates, context)
    best_question = next(iter(ranked))

    if question_type == "Subjective":
        return best_question

    # Distractors must be conditioned on the generated question. The
    # module-level `question` is the Gradio output widget, not question text.
    distractors = generate_distractors(best_question, answer, context)
    # Append rather than index-assign: the distractor list has no slot for the
    # answer, so assigning to a fixed index raised IndexError.
    # NOTE(review): the correct answer is always the last option; consider
    # shuffling the options.
    options = list(distractors) + [answer]

    return best_question + "\n" + "".join(f"-{option}\n" for option in options)
|
|
|
# Right-to-left Gradio UI that wires the three input widgets to
# generate_question and shows the result in the single text output.
_interface_config = {
    "fn": generate_question,
    "inputs": [context, answer, question_type],
    "outputs": question,
    "list_outputs": True,
    "rtl": True,
}
iface = gr.Interface(**_interface_config)

# Launch the UI with debug output and without a public share link.
iface.launch(share=False, debug=True)
|
|