import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os
from typing import Dict, Union

# --- Model and Instruction Configuration ---
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"
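# NOTE: meta-llama checkpoints are gated on the Hugging Face Hub; this assumes
# you have accepted the Llama 3.2 license and are authenticated. The token is
# read from the HF_TOKEN environment variable and passed to from_pretrained
# below (it may be None if you are logged in via `huggingface-cli login`).
HF_TOKEN = os.environ.get("HF_TOKEN")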
SYSTEM_INSTRUCTION = """
You are a strict grading assistant.
Return ONLY a JSON object with:
- accuracy (float 0-10)
- grade (string A-D)
- feedback (string)
"""
# ------------------------------------------

# Load the model and tokenizer once for the entire application.
try:
    print(f"Loading model {MODEL_ID} for Gradio app...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=HF_TOKEN,
    )
    TERMINATORS = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
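    # "<|eot_id|>" is Llama 3's end-of-turn token; including it alongside
    # eos_token_id stops generation at the end of the assistant's turn.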
    MODEL_LOADED = True
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Gradio will run, but the grading function will return an error.")
    MODEL_LOADED = False
    tokenizer, model, TERMINATORS = None, None, None


def grade_response(student_response: str) -> Union[Dict, str]:
    """Grade a single student response and return the parsed result as a dict."""
    if not MODEL_LOADED:
        return {"accuracy": 0.0, "grade": "Error", "feedback": "Model failed to load. Check the console for details."}

    # 1. Construct the message list.
    messages = [
        {"role": "system", "content": SYSTEM_INSTRUCTION},
        {"role": "user", "content": f"Student response to grade: '{student_response}'"},
    ]

    # 2. Apply the chat template and tokenize.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
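    # add_generation_prompt=True appends the assistant header tokens, so the
    # model generates the assistant's reply rather than continuing the user turn.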

    # 3. Generate the output.
    try:
        output_ids = model.generate(
            input_ids,
            max_new_tokens=200,
            eos_token_id=TERMINATORS,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
        )
    except Exception as e:
        return {"accuracy": 0.0, "grade": "Error", "feedback": f"Generation error: {e}"}

    # 4. Decode only the newly generated tokens (slice off the prompt).
    raw_response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True,
    ).strip()

    # 5. Parse the JSON output.
    try:
        # Extract the first '{' ... last '}' span, since the model may wrap
        # the JSON in extra prose despite the instruction.
        start_index = raw_response.find('{')
        end_index = raw_response.rfind('}') + 1
        json_string = raw_response[start_index:end_index]
        return json.loads(json_string)
    except json.JSONDecodeError:
        # If parsing fails, return a structured error response.
        return {"accuracy": 0.0, "grade": "Error", "feedback": f"JSON Decode Error. Raw: {raw_response[:200]}..."}


# --- Gradio Wrapper Function ---
def gradio_grade_wrapper(student_response: str) -> tuple[float, str, str]:
    """Adapt the grading dict to the three Gradio output components."""
    result = grade_response(student_response)
    # Expected path: a structured dict from grade_response.
    if isinstance(result, dict):
        # Gradio outputs: (accuracy, grade, feedback)
        return (
            result.get("accuracy", 0.0),
            result.get("grade", "N/A"),
            result.get("feedback", "No feedback generated."),
        )
    # Fallback for robustness; this should not be reached if grade_response's
    # error handling is correct.
    return (0.0, "ERROR", str(result))


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft(), title="LLM Essay Grader") as demo:
    gr.Markdown("# 📝 LLM Essay Grading Assistant (Llama-3.2-1B-Instruct)")
    gr.Markdown(
        "Enter a student's response below to receive an automated grade, accuracy score, and feedback "
        "from the Llama-3.2-1B-Instruct model."
    )

    # Input components
    with gr.Row():
        student_input = gr.Textbox(
            label="Student Response to Grade",
            placeholder="E.g., 'The main causes of World War 2 were economic depression and poor leadership.'",
            lines=5,
            scale=3,
        )
        grade_button = gr.Button("Submit for Grading", scale=1, variant="primary")
| gr.Markdown("---") | |
| gr.Markdown("## Grading Results") | |
| # Output Components arranged in a Row for visual clarity | |
| with gr.Row(): | |
| accuracy_output = gr.Number(label="Accuracy (0-10)", interactive=False, precision=1) | |
| grade_output = gr.Textbox(label="Grade (A-D)", interactive=False) | |
| feedback_output = gr.Textbox( | |
| label="Detailed Feedback", | |
| interactive=False, | |
| lines=4, | |
| max_lines=10 | |
| ) | |
| # Event Listener: Connect the button click to the wrapper function | |
| grade_button.click( | |
| fn=gradio_grade_wrapper, | |
| inputs=[student_input], | |
| outputs=[accuracy_output, grade_output, feedback_output] | |
| ) | |
| # Add Examples | |
| gr.Examples( | |
| examples=[ | |
| ["The Earth is a cube and its main moon is Mars, which proves that gravity is fake."], | |
| ["A proper noun is a name used to designate a single, specific person, place, or thing, and is always capitalized."], | |
| ["The two main drivers of climate change are the burning of fossil fuels (releasing greenhouse gases) and deforestation."], | |
| ], | |
| inputs=student_input, | |
| ) | |
| # Launch the Gradio App | |
| if __name__ == "__main__": | |
| demo.launch() |
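
# To run this locally (a sketch; exact package versions are assumptions):
#   pip install gradio transformers torch accelerate
#   python app.py
# `accelerate` is needed for device_map="auto". On a Hugging Face Space, the
# same file runs as app.py with these packages listed in requirements.txt.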