QG_Demo / app.py
Reham721's picture
Update app.py
0341a18
raw
history blame
6.47 kB
# --- Model setup -------------------------------------------------------------
# Everything below runs once at import time; each from_pretrained / pipeline
# call fetches its checkpoint from the Hugging Face Hub (or the local cache).
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from arabert.preprocess import ArabertPreprocessor

# Tokenizers: "Subjective_QG" is used for question generation and
# "MCQs" for distractor generation further down in this file.
tokenizer1 = AutoTokenizer.from_pretrained("Reham721/Subjective_QG")
tokenizer2 = AutoTokenizer.from_pretrained("Reham721/MCQs")

# Matching seq2seq models for the two tokenizers above.
model1 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/Subjective_QG")
model2 = AutoModelForSeq2SeqLM.from_pretrained("Reham721/MCQs")

# AraBERT text normaliser applied to the context before question answering.
# Per the original author's note, an empty model name behaves the same here.
prep = ArabertPreprocessor("aubmindlab/araelectra-base-discriminator")

# Extractive QA pipeline whose confidence score is used to rank the
# generated questions.
qa_pipe = pipeline(
    "question-answering",
    model="wissamantoun/araelectra-base-artydiqa",
)
def generate_questions(model, tokenizer, input_sequence,
                       num_return_sequences=3, num_beams=3, max_length=200):
    """Generate candidate questions for *input_sequence* with beam search.

    Args:
        model: a seq2seq model exposing ``generate`` (in this file, one of the
            ``AutoModelForSeq2SeqLM`` instances loaded at module level).
        tokenizer: the tokenizer that belongs to *model*.
        input_sequence: text fed to the model as a single sequence.
        num_return_sequences: number of decoded questions to return; beam
            search cannot return more than *num_beams* sequences.
        num_beams: beam width.
        max_length: maximum length of each generated sequence, in tokens.

    Returns:
        list[str]: the decoded questions, special tokens removed, in the
        order ``generate`` returned them.
    """
    encoding = tokenizer(input_sequence, return_tensors='pt')
    outputs = model.generate(
        input_ids=encoding['input_ids'],
        # A single unpadded sequence has an all-ones mask, so passing it does
        # not change results; it only avoids transformers' missing-mask warning.
        attention_mask=encoding.get('attention_mask'),
        max_length=max_length,
        num_beams=num_beams,
        # Forbid any 3-gram from appearing twice within one generated question.
        no_repeat_ngram_size=3,
        # Finish beam search once num_beams complete candidates exist.
        early_stopping=True,
        # No temperature here: it only affects sampling (do_sample=True), and
        # this is deterministic beam search.
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
def get_sorted_questions(questions, context):
    """Rank candidate questions by how confidently the QA model answers them.

    The context is normalised once with the AraBERT preprocessor (``prep``).
    Each question is then scored with ``qa_pipe`` against that normalised
    context. Each question and its raw pipeline result are printed as a
    debugging aid.

    Args:
        questions: iterable of question strings.
        context: the paragraph the questions were generated from.

    Returns:
        dict[str, float]: question -> QA ``score``, ordered from highest to
        lowest score (ties keep their original order). A repeated question
        keeps one key holding its last score.
    """
    cleaned_context = prep.preprocess(context)
    scores = {}
    for candidate in questions:
        print(candidate)
        qa_output = qa_pipe(question=candidate, context=cleaned_context)
        print(qa_output)
        scores[candidate] = qa_output["score"]
    ranked = sorted(scores, key=scores.get, reverse=True)
    return {candidate: scores[candidate] for candidate in ranked}
import unicodedata
import arabic_reshaper
from bidi.algorithm import get_display
def is_arabic(text):
    """Return True if every alphabetic character in *text* is an Arabic letter.

    A letter counts as Arabic when its Unicode name starts with ``"ARABIC"``.
    This covers the base Arabic block and its presentation forms and ligatures.
    Non-alphabetic characters (digits, whitespace, punctuation, diacritics)
    are ignored. A string with no letters at all, including ``""``, is
    therefore reported as Arabic.

    Reshaping (arabic_reshaper) and bidi reordering (get_display) are not
    applied first. They only substitute presentation forms and reorder
    characters, and neither affects a per-character script check.
    """
    return all(
        # The "" default avoids ValueError for code points without a Unicode
        # name; such characters are simply treated as non-Arabic.
        unicodedata.name(char, "").startswith("ARABIC")
        for char in text
        if char.isalpha()
    )
import random
import re
def _sample_distractor_outputs(input_ids, attention_mask, num_return_sequences):
    """Draw *num_return_sequences* sampled sequences from the MCQ model (model2)."""
    return model2.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=num_return_sequences,
        no_repeat_ngram_size=2,
    )


def _parse_distractor_output(decoded_output):
    """Split one decoded model output into tag-free, non-empty, Arabic-only strings."""
    # Split on any <...> tag (kept via the capture group) and on the literal
    # text "None". Unmatched capture groups come back as None and are skipped.
    pieces = re.split(r'(<[^>]*>)|(?:None)', decoded_output)
    cleaned = [re.sub(r'<[^>]*>', '', piece.strip()) for piece in pieces if piece]
    return [piece for piece in cleaned if piece and is_arabic(piece)]


def generate_distractors(question, answer, context, num_distractors=3, k=10,
                         max_attempts=10):
    """Generate wrong-answer options (distractors) for a multiple-choice question.

    The MCQ model is prompted with ``"{question} <sep> {answer} <sep> {context}"``.
    Sampled outputs are parsed into candidate strings. Candidates that are
    empty, contain non-Arabic letters, equal *answer*, or repeat an earlier
    candidate are discarded.

    Args:
        question: the question text the distractors should fit.
        answer: the correct answer; never returned as a distractor.
        context: the source paragraph.
        num_distractors: how many distractors to return.
        k: at most this many candidates (a random subset) are considered for
            the final random pick.
        max_attempts: upper bound on sampling rounds. This guarantees
            termination when the model keeps repeating itself.

    Returns:
        list[str]: up to *num_distractors* distinct distractors in random
        order. Fewer are returned only if *max_attempts* rounds, or a *k*
        smaller than *num_distractors*, did not leave enough candidates.
    """
    input_sequence = f'{question} <sep> {answer} <sep> {context}'
    encoding = tokenizer2(input_sequence, return_tensors='pt')
    input_ids = encoding['input_ids']
    attention_mask = encoding.get('attention_mask')

    unique_distractors = []
    attempts = 0
    while len(unique_distractors) < num_distractors and attempts < max_attempts:
        attempts += 1
        # Ask only for as many new sequences as are still missing.
        outputs = _sample_distractor_outputs(
            input_ids, attention_mask, num_distractors - len(unique_distractors)
        )
        for output in outputs:
            decoded_output = tokenizer2.decode(output, skip_special_tokens=True)
            # Keep only parsed, filtered pieces, never the raw decoded text.
            for distractor in _parse_distractor_output(decoded_output):
                if distractor != answer and distractor not in unique_distractors:
                    unique_distractors.append(distractor)

    random.shuffle(unique_distractors)
    # The list is shuffled, so its first k items are a random k-subset.
    candidates = unique_distractors[:k]
    return random.sample(candidates, min(num_distractors, len(candidates)))
import gradio as gr
# Gradio input/output components, built with the legacy gr.inputs / gr.outputs API.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and removed
# in 4.x; migrating to gr.Textbox / gr.Radio depends on the pinned version.
# NOTE(review): these module-level names match parameter names used by
# functions in this file; treat them only as UI components, not as data.
context = gr.inputs.Textbox(
    lines=5,
    placeholder="Enter paragraph/context here...",
)
answer = gr.inputs.Textbox(
    lines=3,
    placeholder="Enter answer/keyword here...",
)
question_type = gr.inputs.Radio(
    choices=["Subjective", "MCQ"],
    label="Question type",
)
question = gr.outputs.Textbox(
    type="text",
    label="Question",
)
def generate_question(context, answer, question_type):
    """Gradio callback: build a subjective or multiple-choice question.

    Args:
        context: the paragraph entered in the context textbox.
        answer: the answer/keyword the question should target.
        question_type: value of the radio input. "Subjective" returns only the
            question; any other value (including no selection) produces an MCQ.

    Returns:
        str: for "Subjective", the top-ranked question. Otherwise the question
        followed by one ``"-<option>\\n"`` line per choice: the generated
        distractors first, then the correct answer last.
    """
    article = answer + "<sep>" + context
    candidates = generate_questions(model1, tokenizer1, article)
    ranked = get_sorted_questions(candidates, context)
    # get_sorted_questions orders questions by QA score, best first.
    best_question = next(iter(ranked))
    if question_type == "Subjective":
        return best_question

    # Condition the distractor model on the generated question text, not on
    # the module-level gradio output component that is also named `question`.
    distractors = generate_distractors(best_question, answer, context)
    # Append the answer rather than assigning to a fixed index: the distractor
    # list holds only num_distractors items (3 by default), so index 3 does
    # not exist.
    options = list(distractors) + [answer]
    # NOTE(review): options are not shuffled, so the correct answer is always
    # the last choice; consider random.shuffle(options) if that is unwanted.
    return best_question + "\n" + "".join("-" + option + "\n" for option in options)
# Wire the callback into a web UI. The three inputs map, in order, onto
# generate_question(context, answer, question_type); the single text output
# shows the returned string.
iface = gr.Interface(
    fn=generate_question,
    inputs=[context, answer, question_type],
    outputs=question,
    # NOTE(review): list_outputs and rtl do not appear to be gr.Interface
    # parameters and may be ignored (rtl is a per-component Textbox option in
    # newer Gradio releases). Confirm against the pinned Gradio version.
    list_outputs=True,
    rtl=True,
)
# debug=True blocks the main thread and surfaces errors in the console.
# share=False means no public gradio.live link is created.
iface.launch(debug=True, share=False)