# GradGradio / app.py
# Last commit by redc007: "Create app.py" (9c4c110, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os
from typing import Dict, Union
# --- Model and Instruction Configuration ---
# Hugging Face Hub id of the chat model used for grading.
MODEL_ID: str = "meta-llama/Llama-3.2-1B-Instruct"
# System prompt sent ahead of every student response. It asks the model to
# reply with a bare JSON object carrying the three fields the UI displays.
SYSTEM_INSTRUCTION: str = (
    "\n"
    "You are a strict grading assistant.\n"
    "Return ONLY a JSON object with:\n"
    "- accuracy (float 0-10)\n"
    "- grade (string A-D)\n"
    "- feedback (string)\n"
)
# ------------------------------------------
# Load Model and Tokenizer once for the entire application.
# A load failure is deliberately non-fatal: the UI still starts, and
# grade_response() returns an error result because MODEL_LOADED is False.
try:
    print(f"Loading model {MODEL_ID} for Gradio app...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    # Generation stops at either the tokenizer's EOS id or the id of the
    # "<|eot_id|>" end-of-turn token used by the Llama 3 chat format.
    TERMINATORS = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
except Exception as load_error:
    print(f"Error loading model or tokenizer: {load_error}")
    print("Gradio will run, but the grading function will return an error.")
    MODEL_LOADED = False
    tokenizer = model = TERMINATORS = None
else:
    MODEL_LOADED = True
def _extract_json_object(text: str) -> Dict:
    """Return the first well-formed JSON object embedded in ``text``.

    Each ``{`` in ``text`` is tried, left to right, as the start of a JSON
    document. Anything before the object and anything after its closing
    brace (a trailing remark, a second object, stray braces) is ignored.

    Raises:
        json.JSONDecodeError: if no ``{`` in ``text`` begins a valid object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            parsed, _end = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        # A JSON document that starts with "{" can only decode to a dict.
        return parsed
    raise json.JSONDecodeError("No JSON object found", text, 0)


def grade_response(student_response: str) -> Dict:
    """Grade a free-text student response with the loaded chat model.

    The response is sent to the model after ``SYSTEM_INSTRUCTION``. The
    first JSON object found in the model's reply is returned as-is.

    Args:
        student_response: The student's answer text. Empty, whitespace-only
            or ``None`` input is rejected without calling the model.

    Returns:
        The JSON object parsed from the model's reply. ``SYSTEM_INSTRUCTION``
        asks for ``accuracy``, ``grade`` and ``feedback`` keys, but the
        content is not validated here. Blank input, a model that failed to
        load, tokenization/generation errors and unparseable output each
        produce a dict with ``accuracy=0.0``, ``grade="Error"`` and an
        explanatory ``feedback`` string.
    """
    # Reject blank input up front: there is nothing for the model to grade.
    if not (student_response or "").strip():
        return {
            "accuracy": 0.0,
            "grade": "Error",
            "feedback": "No student response provided. Please enter text to grade.",
        }
    if not MODEL_LOADED:
        return {"accuracy": 0.0, "grade": "Error", "feedback": "Model failed to load. Check console for details."}

    # 1. Construct the Message List
    messages = [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": f"Student response to grade: '{student_response}'"},
    ]

    # 2-3. Apply the chat template and generate under one guard, so tokenizer
    # failures are reported to the user instead of escaping to Gradio.
    try:
        # return_dict=True returns both input_ids and attention_mask, so
        # generate() gets an explicit mask rather than having to infer one.
        encoded = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(model.device)
        input_ids = encoded["input_ids"]
        output_ids = model.generate(
            input_ids,
            attention_mask=encoded["attention_mask"],
            max_new_tokens=200,
            eos_token_id=TERMINATORS,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
        )
    except Exception as e:
        return {"accuracy": 0.0, "grade": "Error", "feedback": f"Generation error: {e}"}

    # 4. Decode only the newly generated tokens (skip the echoed prompt).
    raw_response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
    ).strip()

    # 5. Parse the JSON Output
    try:
        return _extract_json_object(raw_response)
    except json.JSONDecodeError:
        # Only mark the preview as truncated when it actually was.
        preview = raw_response if len(raw_response) <= 200 else raw_response[:200] + "..."
        return {"accuracy": 0.0, "grade": "Error", "feedback": f"JSON Decode Error. Raw: {preview}"}
# --- Gradio Wrapper Function ---
def gradio_grade_wrapper(student_response: str) -> tuple[float, str, str]:
    """Run ``grade_response`` and map its result onto the three Gradio outputs.

    The model's JSON is not schema-checked upstream, so values are
    normalised before they reach the components:

    - ``accuracy`` is converted to ``float``. It falls back to 0.0 when it is
      missing, null or non-numeric (e.g. ``"8/10"``).
    - ``grade`` and ``feedback`` are converted to ``str``. Defaults are used
      when they are missing or null.

    Returns:
        ``(accuracy, grade, feedback)``, matching the order of the click
        handler's ``outputs`` list.
    """
    result = grade_response(student_response)
    if not isinstance(result, dict):
        # grade_response is expected to always return a dict; this fallback
        # only guards against an unexpected return value.
        return (0.0, "ERROR", str(result))

    try:
        accuracy = float(result.get("accuracy", 0.0))
    except (TypeError, ValueError):
        accuracy = 0.0
    grade = result.get("grade")
    feedback = result.get("feedback")
    return (
        accuracy,
        "N/A" if grade is None else str(grade),
        "No feedback generated." if feedback is None else str(feedback),
    )
# --- Gradio Interface Definition ---
# Components appear on the page in the order they are created inside the
# `with` contexts below, so statement order here is also layout order.
with gr.Blocks(theme=gr.themes.Soft(), title="LLM Essay Grader") as demo:
    # Page header and short usage description.
    gr.Markdown("# 📝 LLM Essay Grading Assistant (Llama-3.2-1B-Instruct)")
    gr.Markdown(
        "Enter a student's response below to receive an automated grade, accuracy score, and feedback "
        "from the Llama-3.2-1B-Instruct model."
    )
    # Input Component: the textbox and submit button share one row; `scale`
    # sets their relative widths (3:1).
    with gr.Row():
        student_input = gr.Textbox(
            label="Student Response to Grade",
            placeholder="E.g., 'The main causes of the World War 2 were economic depression and poor leadership.'",
            lines=5,
            scale=3
        )
        grade_button = gr.Button("Submit for Grading", scale=1, variant="primary")
    gr.Markdown("---")
    gr.Markdown("## Grading Results")
    # Output Components arranged in a Row for visual clarity.
    # interactive=False makes them display-only; precision=1 rounds the
    # displayed accuracy to one decimal place.
    with gr.Row():
        accuracy_output = gr.Number(label="Accuracy (0-10)", interactive=False, precision=1)
        grade_output = gr.Textbox(label="Grade (A-D)", interactive=False)
    # Free-text feedback, shown below the score row.
    feedback_output = gr.Textbox(
        label="Detailed Feedback",
        interactive=False,
        lines=4,
        max_lines=10
    )
    # Event Listener: Connect the button click to the wrapper function.
    # gradio_grade_wrapper returns (accuracy, grade, feedback), which must
    # match the order of `outputs` below.
    grade_button.click(
        fn=gradio_grade_wrapper,
        inputs=[student_input],
        outputs=[accuracy_output, grade_output, feedback_output]
    )
    # Add Examples: clicking one fills student_input. No `fn` is given, so
    # selecting an example does not run grading by itself.
    gr.Examples(
        examples=[
            ["The Earth is a cube and its main moon is Mars, which proves that gravity is fake."],
            ["A proper noun is a name used to designate a single, specific person, place, or thing, and is always capitalized."],
            ["The two main drivers of climate change are the burning of fossil fuels (releasing greenhouse gases) and deforestation."],
        ],
        inputs=student_input,
    )
# Launch the Gradio App
def _main() -> None:
    """Start the Gradio server for ``demo`` with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()