# EdNA MCQ demo — Hugging Face Space app script.
# (Removed page-scrape residue: "Spaces: / Sleeping / Sleeping".)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Model Loading ---
# Load the EdNA checkpoint once at module import. device_map="auto" lets the
# accelerate backend place the weights on whatever hardware is available.
model_id = "HAissa/EdNA"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
def get_answer(question, ans_a, ans_b, ans_c, ans_d):
    """Ask the model a multiple-choice question and return its chosen option.

    Args:
        question: The question text.
        ans_a, ans_b, ans_c, ans_d: The four candidate answers.

    Returns:
        The text of the option the model picked (A-D), the raw generated
        text when no option letter could be recognised, or an error
        message when the model produced nothing usable.
    """
    options = f"A) {ans_a}\nB) {ans_b}\nC) {ans_c}\nD) {ans_d}"
    prompt = f"Question: {question}\nOptions:\n{options}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate a short continuation; the answer letter comes first.
    outputs = model.generate(**inputs, max_new_tokens=3)
    # Decode ONLY the newly generated tokens. The previous approach of
    # splitting the full decoded text on "Answer:" broke whenever the
    # question or an option itself contained that substring.
    prompt_len = inputs["input_ids"].shape[1]
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    final_answer = generated.strip().split("\n")[0].strip()
    if not final_answer:
        return "Could not parse the model's answer."
    # Map a recognised leading letter ("A", "A)", "A.", "A:") to its option;
    # the old code only accepted the exact "A)" prefix.
    for letter, option in (("A", ans_a), ("B", ans_b), ("C", ans_c), ("D", ans_d)):
        if final_answer == letter or final_answer.startswith(
            (letter + ")", letter + ".", letter + ":")
        ):
            return option
    # Fall back to the raw generated text so the user still sees something.
    return final_answer
# --- Gradio Interface ---
# Soft theme tuned for a fresh, friendly look: emerald as the primary hue,
# the Poppins Google font, and large corner radii for rounded components.
_font_stack = [
    gr.themes.GoogleFont("Poppins"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
]
theme = gr.themes.Soft(primary_hue="emerald", font=_font_stack, radius_size="lg")
# Build the UI: a question box, a 2x2 grid of options, a predict button,
# and a read-only output wired to get_answer.
with gr.Blocks(theme=theme) as demo:
    # NOTE: the emoji below were mojibake ("π€", "β¨") — repaired to the
    # intended 🤖 / ✨ characters.
    gr.Markdown(
        """
        # 🤖 EdNA: MCQ Answering AI
        Enter your question and options below, then click **Predict Answer** to see the model's choice.
        """
    )
    with gr.Group():
        # Question input: larger, but limited max_lines to prevent excessive scrolling
        question_input = gr.Textbox(
            label="Question",
            placeholder="Type the full question here...",
            lines=3,
            max_lines=5,
        )
        # 2x2 grid for a compact, modern MCQ layout.
        with gr.Row():
            # max_lines=1 keeps these as single-line, non-scrollable fields.
            answer_a_input = gr.Textbox(label="", placeholder="Answer A", lines=1, max_lines=1)
            answer_b_input = gr.Textbox(label="", placeholder="Answer B", lines=1, max_lines=1)
        with gr.Row():
            answer_c_input = gr.Textbox(label="", placeholder="Answer C", lines=1, max_lines=1)
            answer_d_input = gr.Textbox(label="", placeholder="Answer D", lines=1, max_lines=1)
    # A larger, more prominent primary action button.
    get_answer_button = gr.Button("✨ Predict Answer", variant="primary", size="lg")
    # Distinct read-only output box for the model's prediction.
    final_answer_output = gr.Textbox(
        label="Model Prediction",
        interactive=False,
        lines=2,
        placeholder="The result will appear here...",
    )
    # Wire the button to the inference function.
    get_answer_button.click(
        fn=get_answer,
        inputs=[question_input, answer_a_input, answer_b_input, answer_c_input, answer_d_input],
        outputs=final_answer_output,
    )

if __name__ == "__main__":
    demo.launch()