# EdNA-STEM-MCQ / app.py
# HAissa's picture
# Update app.py
# f57a822 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Model Loading ---
# Hub identifier shared by the model weights and the tokenizer.
model_id = "HAissa/EdNA"

# device_map="auto" delegates device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
def get_answer(question, ans_a, ans_b, ans_c, ans_d):
    """Ask the model to pick one of four options and return the chosen text.

    The question and options are formatted into a single prompt ending in
    ``Answer:``. The model generates at most three new tokens, and the first
    line of that continuation is interpreted as the model's choice.

    Args:
        question: The question text.
        ans_a: Text of option A.
        ans_b: Text of option B.
        ans_c: Text of option C.
        ans_d: Text of option D.

    Returns:
        The text of the selected option when the continuation starts with an
        option letter (``A``-``D``) that stands alone or is followed by
        ``)``, ``.`` or ``:``. Otherwise the raw first line of the
        continuation, or a fallback message when the continuation is empty.
    """
    choices = {"A": ans_a, "B": ans_b, "C": ans_c, "D": ans_d}
    options = "\n".join(f"{letter}) {text}" for letter, text in choices.items())
    prompt = f"Question: {question}\nOptions:\n{options}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate the answer
    outputs = model.generate(**inputs, max_new_tokens=3)

    # Decode only the tokens produced after the prompt. Splitting the full
    # decoded text on "Answer:" would pick up user-supplied text whenever the
    # question or an option itself contains that marker.
    prompt_length = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )

    first_line = completion.strip().split("\n")[0].strip()
    if not first_line:
        return "Could not parse the model's answer."

    # Accept "A)", "A.", "A:" or a bare "A"; anything else (e.g. a word that
    # merely begins with A-D) is returned verbatim rather than mis-mapped.
    letter, suffix = first_line[:1], first_line[1:2]
    if letter in choices and suffix in ("", ")", ".", ":"):
        return choices[letter]
    return first_line
# --- Gradio Interface ---
# Soft base theme with an emerald accent, the Poppins Google font (falling
# back to generic sans-serif families) and large corner radii for a rounded,
# modern look.
theme = gr.themes.Soft(
    radius_size="lg",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
    primary_hue="emerald",
)
# Build the page: a heading, the question with its four options, a button that
# triggers prediction, and a read-only box that shows the result.
# NOTE(review): nesting was reconstructed from flattened source; the button and
# output box are assumed to sit outside the Group -- confirm the intended layout.
with gr.Blocks(theme=theme) as demo:
    gr.Markdown(
        """
        # 🤖 EdNA: MCQ Answering AI
        Enter your question and options below, then click **Predict Answer** to see the model's choice.
        """
    )
    with gr.Group():
        # Question input: larger, but limited max_lines to prevent excessive scrolling
        question_input = gr.Textbox(
            label="Question",
            placeholder="Type the full question here...",
            lines=3,
            max_lines=5,
        )
        # 2x2 Grid for a compact, modern MCQ layout
        with gr.Row():
            # setting max_lines=1 ensures these stay as single-line, non-scrollable input fields
            answer_a_input = gr.Textbox(label="", placeholder="Answer A", lines=1, max_lines=1)
            answer_b_input = gr.Textbox(label="", placeholder="Answer B", lines=1, max_lines=1)
        with gr.Row():
            answer_c_input = gr.Textbox(label="", placeholder="Answer C", lines=1, max_lines=1)
            answer_d_input = gr.Textbox(label="", placeholder="Answer D", lines=1, max_lines=1)

    # A larger, more prominent button
    get_answer_button = gr.Button("✨ Predict Answer", variant="primary", size="lg")

    # Distinct output box
    final_answer_output = gr.Textbox(
        label="Model Prediction",
        interactive=False,
        lines=2,
        placeholder="The result will appear here...",
    )

    # Clicking the button passes the five text fields to get_answer, in order,
    # and writes its return value into the output box.
    get_answer_button.click(
        fn=get_answer,
        inputs=[question_input, answer_a_input, answer_b_input, answer_c_input, answer_d_input],
        outputs=final_answer_output,
    )
# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()