Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,17 +2,11 @@ import gradio as gr
|
|
| 2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
|
| 4 |
# --- Model Loading ---
|
| 5 |
-
# It's recommended to load the model and tokenizer once outside the function
|
| 6 |
-
# to avoid reloading them on every call.
|
| 7 |
model_id = "HAissa/EdNA"
|
| 8 |
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
|
| 9 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 10 |
|
| 11 |
def get_answer(question, ans_a, ans_b, ans_c, ans_d):
|
| 12 |
-
"""
|
| 13 |
-
This function takes a question and four answers, formats them for the EdNA model,
|
| 14 |
-
and returns the model's predicted answer.
|
| 15 |
-
"""
|
| 16 |
options = f"A) {ans_a}\nB) {ans_b}\nC) {ans_c}\nD) {ans_d}"
|
| 17 |
prompt = f"Question: {question}\nOptions:\n{options}\nAnswer:"
|
| 18 |
|
|
@@ -22,7 +16,6 @@ def get_answer(question, ans_a, ans_b, ans_c, ans_d):
|
|
| 22 |
outputs = model.generate(**inputs, max_new_tokens=3)
|
| 23 |
answer_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 24 |
|
| 25 |
-
# Parse the final answer
|
| 26 |
try:
|
| 27 |
final_answer = answer_text.split("Answer:")[1].strip().split('\n')[0]
|
| 28 |
if final_answer.startswith("A)"):
|
|
@@ -35,7 +28,6 @@ def get_answer(question, ans_a, ans_b, ans_c, ans_d):
|
|
| 35 |
return ans_d
|
| 36 |
else:
|
| 37 |
return final_answer
|
| 38 |
-
|
| 39 |
except IndexError:
|
| 40 |
final_answer = "Could not parse the model's answer."
|
| 41 |
|
|
@@ -43,26 +35,50 @@ def get_answer(question, ans_a, ans_b, ans_c, ans_d):
|
|
| 43 |
|
| 44 |
# --- Gradio Interface ---
|
| 45 |
|
| 46 |
-
#
|
|
|
|
| 47 |
theme = gr.themes.Soft(
|
| 48 |
-
primary_hue="
|
| 49 |
-
font=[gr.themes.GoogleFont("
|
|
|
|
| 50 |
)
|
| 51 |
|
| 52 |
with gr.Blocks(theme=theme) as demo:
|
| 53 |
-
gr.Markdown(
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
get_answer_button = gr.Button("
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
get_answer_button.click(
|
| 68 |
fn=get_answer,
|
|
|
|
| 2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
|
| 4 |
# --- Model Loading ---
|
|
|
|
|
|
|
| 5 |
model_id = "HAissa/EdNA"
|
| 6 |
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
|
| 7 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 8 |
|
| 9 |
def get_answer(question, ans_a, ans_b, ans_c, ans_d):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
options = f"A) {ans_a}\nB) {ans_b}\nC) {ans_c}\nD) {ans_d}"
|
| 11 |
prompt = f"Question: {question}\nOptions:\n{options}\nAnswer:"
|
| 12 |
|
|
|
|
| 16 |
outputs = model.generate(**inputs, max_new_tokens=3)
|
| 17 |
answer_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 18 |
|
|
|
|
| 19 |
try:
|
| 20 |
final_answer = answer_text.split("Answer:")[1].strip().split('\n')[0]
|
| 21 |
if final_answer.startswith("A)"):
|
|
|
|
| 28 |
return ans_d
|
| 29 |
else:
|
| 30 |
return final_answer
|
|
|
|
| 31 |
except IndexError:
|
| 32 |
final_answer = "Could not parse the model's answer."
|
| 33 |
|
|
|
|
| 35 |
|
| 36 |
# --- Gradio Interface ---
|
| 37 |
|
| 38 |
+
# Use a modern font (Poppins) and 'emerald' for a fresh green look.
|
| 39 |
+
# Increased radius_size gives components a friendlier, modern rounded look.
|
| 40 |
theme = gr.themes.Soft(
|
| 41 |
+
primary_hue="emerald",
|
| 42 |
+
font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
|
| 43 |
+
radius_size="lg"
|
| 44 |
)
|
| 45 |
|
| 46 |
with gr.Blocks(theme=theme) as demo:
|
| 47 |
+
gr.Markdown(
|
| 48 |
+
"""
|
| 49 |
+
# 🤖 EdNA: MCQ Answering AI
|
| 50 |
+
Enter your question and options below, then click **Predict Answer** to see the model's choice.
|
| 51 |
+
"""
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
with gr.Group():
|
| 55 |
+
# Question input: larger, but limited max_lines to prevent excessive scrolling
|
| 56 |
+
question_input = gr.Textbox(
|
| 57 |
+
label="Question",
|
| 58 |
+
placeholder="Type the full question here...",
|
| 59 |
+
lines=3,
|
| 60 |
+
max_lines=5
|
| 61 |
+
)
|
| 62 |
|
| 63 |
+
# 2x2 Grid for a compact, modern MCQ layout
|
| 64 |
+
with gr.Row():
|
| 65 |
+
# setting max_lines=1 ensures these stay as single-line, non-scrollable input fields
|
| 66 |
+
answer_a_input = gr.Textbox(label="Option A", placeholder="Answer A", lines=1, max_lines=1)
|
| 67 |
+
answer_b_input = gr.Textbox(label="Option B", placeholder="Answer B", lines=1, max_lines=1)
|
| 68 |
+
with gr.Row():
|
| 69 |
+
answer_c_input = gr.Textbox(label="Option C", placeholder="Answer C", lines=1, max_lines=1)
|
| 70 |
+
answer_d_input = gr.Textbox(label="Option D", placeholder="Answer D", lines=1, max_lines=1)
|
| 71 |
|
| 72 |
+
# A larger, more prominent button
|
| 73 |
+
get_answer_button = gr.Button("✨ Predict Answer", variant="primary", size="lg")
|
| 74 |
+
|
| 75 |
+
# Distinct output box
|
| 76 |
+
final_answer_output = gr.Textbox(
|
| 77 |
+
label="Model Prediction",
|
| 78 |
+
interactive=False,
|
| 79 |
+
lines=2,
|
| 80 |
+
placeholder="The result will appear here..."
|
| 81 |
+
)
|
| 82 |
|
| 83 |
get_answer_button.click(
|
| 84 |
fn=get_answer,
|