import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import evaluate
import re
import matplotlib
matplotlib.use('Agg') # for non-interactive envs
import matplotlib.pyplot as plt
import io
import base64
# ---------------------------------------------------------------------------
# 1. Define model name and load model/tokenizer
# ---------------------------------------------------------------------------
model_name = "meta-llama/Llama-3.2-1B-Instruct" # fictional placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
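
# Optional: if a GPU is available, move the model over. Half precision is a
# memory-saving assumption here, not a requirement; safe to skip on CPU-only
# hardware.
if torch.cuda.is_available():
    model = model.to("cuda", dtype=torch.float16)
model.eval()  # inference only: disables dropout and similar training behavior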
# ---------------------------------------------------------------------------
# 2. Define a tiny "dataset" for demonstration.
#    In practice, load a real dataset from the Hugging Face Hub or your own
#    files (a commented sketch follows the toy data below).
# ---------------------------------------------------------------------------
test_data = [
    {"question": "What is 2+2?", "answer": "4"},
    {"question": "What is 3*3?", "answer": "9"},
    {"question": "What is 10/2?", "answer": "5"},
]
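
# A sketch of swapping in a real benchmark (hypothetical usage; note that
# gsm8k answers embed step-by-step reasoning, so parse_answer would need
# adjusting to read the final "#### <number>" line):
# from datasets import load_dataset
# gsm8k = load_dataset("gsm8k", "main", split="test")
# test_data = [{"question": ex["question"], "answer": ex["answer"]} for ex in gsm8k]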
# ---------------------------------------------------------------------------
# 3. Load a string-comparison metric from the Hugging Face evaluate library.
#    ("exact_match" fits free-text answers; the "accuracy" metric expects
#    integer class labels, not strings.)
# ---------------------------------------------------------------------------
exact_match_metric = evaluate.load("exact_match")
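
# Illustrative call (exact_match returns a fraction between 0.0 and 1.0):
#   exact_match_metric.compute(predictions=["4"], references=["4"])
#   -> {"exact_match": 1.0}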
# ---------------------------------------------------------------------------
# 4. Inference helper functions
# ---------------------------------------------------------------------------
def generate_answer(question):
    """
    Generates an answer to the given question using the loaded model.
    """
    # Simple prompt
    prompt = f"Question: {question}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=30,
            do_sample=False,  # greedy decoding: deterministic
            pad_token_id=tokenizer.eos_token_id,  # Llama tokenizers define no pad token
        )
    # Decode only the newly generated tokens; decoding the full sequence would
    # echo the prompt, and parse_answer would then pick up digits from the
    # question itself.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def parse_answer(model_output):
    """
    Heuristic to extract the final numeric answer from the model's text.
    You can customize this regex or logic as needed.
    """
    # Example: find digits (possibly multiple, but we keep the first match)
    match = re.search(r"(\d+)", model_output)
    if match:
        return match.group(1)
    # Fall back to the entire text if no digits are found
    return model_output.strip()
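
# Quick illustrative checks of the parsing heuristic:
#   parse_answer("The answer is 4.")  -> "4"
#   parse_answer("no digits here")    -> "no digits here"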
# ---------------------------------------------------------------------------
# 5. Evaluation routine
# ---------------------------------------------------------------------------
def run_evaluation():
    predictions = []
    references = []
    for sample in test_data:
        question = sample["question"]
        reference_answer = sample["answer"]
        # Model inference
        model_output = generate_answer(question)
        predicted_answer = parse_answer(model_output)
        predictions.append(predicted_answer)
        references.append(reference_answer)

    # Normalize answers (lowercase and strip surrounding whitespace)
    def normalize_answer(ans):
        return ans.lower().strip()

    norm_preds = [normalize_answer(p) for p in predictions]
    norm_refs = [normalize_answer(r) for r in references]

    # Compute the share of exact string matches
    results = exact_match_metric.compute(predictions=norm_preds, references=norm_refs)
    accuracy = results["exact_match"]

    # Create a simple bar chart: correct vs. incorrect
    correct_count = sum(p == r for p, r in zip(norm_preds, norm_refs))
    incorrect_count = len(test_data) - correct_count
    fig, ax = plt.subplots()
    ax.bar(["Correct", "Incorrect"], [correct_count, incorrect_count], color=["green", "red"])
    ax.set_title("Evaluation Results")
    ax.set_ylabel("Count")
    ax.set_ylim([0, len(test_data)])

    # Convert the plot to a base64-encoded PNG for Gradio display
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)
    data = base64.b64encode(buf.read()).decode("utf-8")
    # gr.HTML renders markup, so wrap the data URL in an <img> tag;
    # a bare data URL would display as plain text
    image_html = f'<img src="data:image/png;base64,{data}" alt="Evaluation bar chart">'

    # Return text and the plot
    return f"Accuracy: {accuracy:.2f}", image_html

# ---------------------------------------------------------------------------
# 6. Gradio App
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Simple Math Evaluation with 'Llama 3.2'")
    eval_button = gr.Button("Run Evaluation")
    output_text = gr.Textbox(label="Results")
    output_plot = gr.HTML(label="Plot")
    eval_button.click(
        fn=run_evaluation,
        inputs=None,
        outputs=[output_text, output_plot],
    )

demo.launch()
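
# Run locally with `python app.py` and open the printed URL; on a Hugging Face
# Space, this file serves as the app entry point and launches automatically.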